Make EntityComponent generic (#78473)
This commit is contained in:
parent
fd05d949cc
commit
996bcbdac6
21 changed files with 53 additions and 50 deletions
|
@ -154,12 +154,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if entity_id in cast(AutomationEntity, automation_entity).referenced_entities
|
||||
if entity_id in automation_entity.referenced_entities
|
||||
]
|
||||
|
||||
|
||||
|
@ -169,12 +169,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_entities)
|
||||
return list(automation_entity.referenced_entities)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -183,12 +183,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if device_id in cast(AutomationEntity, automation_entity).referenced_devices
|
||||
if device_id in automation_entity.referenced_devices
|
||||
]
|
||||
|
||||
|
||||
|
@ -198,12 +198,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_devices)
|
||||
return list(automation_entity.referenced_devices)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -212,12 +212,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if area_id in cast(AutomationEntity, automation_entity).referenced_areas
|
||||
if area_id in automation_entity.referenced_areas
|
||||
]
|
||||
|
||||
|
||||
|
@ -227,12 +227,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_areas)
|
||||
return list(automation_entity.referenced_areas)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -252,7 +252,9 @@ def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up all automations."""
|
||||
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
|
||||
hass.data[DOMAIN] = component = EntityComponent[AutomationEntity](
|
||||
LOGGER, DOMAIN, hass
|
||||
)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -93,7 +93,7 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the counters."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[Counter](_LOGGER, DOMAIN, hass)
|
||||
id_manager = collection.IDManager()
|
||||
|
||||
yaml_collection = collection.YamlCollection(
|
||||
|
|
|
@ -67,9 +67,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
"""Set up is called when Home Assistant is loading our component."""
|
||||
dominos = Dominos(hass, config)
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[DominosOrder](_LOGGER, DOMAIN, hass)
|
||||
hass.data[DOMAIN] = {}
|
||||
entities = []
|
||||
entities: list[DominosOrder] = []
|
||||
conf = config[DOMAIN]
|
||||
|
||||
hass.services.register(
|
||||
|
|
|
@ -85,7 +85,9 @@ class FaceInformation(TypedDict, total=False):
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the image processing."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL)
|
||||
component = EntityComponent[ImageProcessingEntity](
|
||||
_LOGGER, DOMAIN, hass, SCAN_INTERVAL
|
||||
)
|
||||
|
||||
await component.async_setup(config)
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input boolean."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputBoolean](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -77,7 +77,7 @@ class InputButtonStorageCollection(collection.StorageCollection):
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input button."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputButton](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -130,7 +130,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input datetime."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputDatetime](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -107,7 +107,7 @@ STORAGE_VERSION = 1
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input slider."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputNumber](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -132,7 +132,7 @@ class InputSelectStore(Store):
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input select."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputSelect](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -107,7 +107,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input text."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[InputText](_LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -83,7 +83,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
mailboxes.append(mailbox)
|
||||
mailbox_entity = MailboxEntity(mailbox)
|
||||
component = EntityComponent(
|
||||
component = EntityComponent[MailboxEntity](
|
||||
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL
|
||||
)
|
||||
await component.async_add_entities([mailbox_entity])
|
||||
|
|
|
@ -326,7 +326,7 @@ The following persons point at invalid users:
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the person component."""
|
||||
entity_component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entity_component = EntityComponent[Person](_LOGGER, DOMAIN, hass)
|
||||
id_manager = collection.IDManager()
|
||||
yaml_collection = collection.YamlCollection(
|
||||
logging.getLogger(f"{__name__}.yaml_collection"), id_manager
|
||||
|
|
|
@ -111,7 +111,7 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: {cv.string: PLANT_SCHEMA}}, extra=vol.ALLOW_
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Plant component."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[Plant](_LOGGER, DOMAIN, hass)
|
||||
|
||||
entities = []
|
||||
for plant_name, plant_config in config[DOMAIN].items():
|
||||
|
|
|
@ -52,7 +52,7 @@ SERVICE_SCHEMA_COMPLETE_TASK = vol.Schema({vol.Required(CONF_ID): cv.string})
|
|||
|
||||
def setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Remember the milk component."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[RememberTheMilk](_LOGGER, DOMAIN, hass)
|
||||
|
||||
stored_rtm_config = RememberTheMilkConfiguration(hass)
|
||||
for rtm_config in config[DOMAIN]:
|
||||
|
|
|
@ -27,6 +27,7 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.helpers import discovery, template
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import (
|
||||
DEFAULT_SCAN_INTERVAL,
|
||||
EntityComponent,
|
||||
|
@ -53,7 +54,7 @@ COORDINATOR_AWARE_PLATFORMS = [SENSOR_DOMAIN, BINARY_SENSOR_DOMAIN]
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the rest platforms."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[Entity](_LOGGER, DOMAIN, hass)
|
||||
_async_setup_shared_data(hass)
|
||||
|
||||
async def reload_service_handler(service: ServiceCall) -> None:
|
||||
|
|
|
@ -154,7 +154,7 @@ ENTITY_SCHEMA = vol.Schema(
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input select."""
|
||||
component = EntityComponent(LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[Schedule](LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
|
|
@ -183,7 +183,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Load the scripts from the configuration."""
|
||||
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
|
||||
hass.data[DOMAIN] = component = EntityComponent[ScriptEntity](LOGGER, DOMAIN, hass)
|
||||
|
||||
# Process integration platforms right away since
|
||||
# we will create entities before firing EVENT_COMPONENT_LOADED
|
||||
|
@ -205,9 +205,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async def turn_on_service(service: ServiceCall) -> None:
|
||||
"""Call a service to turn script on."""
|
||||
variables = service.data.get(ATTR_VARIABLES)
|
||||
script_entities: list[ScriptEntity] = cast(
|
||||
list[ScriptEntity], await component.async_extract_from_service(service)
|
||||
)
|
||||
script_entities = await component.async_extract_from_service(service)
|
||||
for script_entity in script_entities:
|
||||
await script_entity.async_turn_on(
|
||||
variables=variables, context=service.context, wait=False
|
||||
|
@ -216,9 +214,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async def turn_off_service(service: ServiceCall) -> None:
|
||||
"""Cancel a script."""
|
||||
# Stopping a script is ok to be done in parallel
|
||||
script_entities: list[ScriptEntity] = cast(
|
||||
list[ScriptEntity], await component.async_extract_from_service(service)
|
||||
)
|
||||
script_entities = await component.async_extract_from_service(service)
|
||||
|
||||
if not script_entities:
|
||||
return
|
||||
|
@ -232,9 +228,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
async def toggle_service(service: ServiceCall) -> None:
|
||||
"""Toggle a script."""
|
||||
script_entities: list[ScriptEntity] = cast(
|
||||
list[ScriptEntity], await component.async_extract_from_service(service)
|
||||
)
|
||||
script_entities = await component.async_extract_from_service(service)
|
||||
for script_entity in script_entities:
|
||||
await script_entity.async_toggle(context=service.context, wait=False)
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up an input select."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = EntityComponent[Timer](_LOGGER, DOMAIN, hass)
|
||||
id_manager = collection.IDManager()
|
||||
|
||||
yaml_collection = collection.YamlCollection(
|
||||
|
|
|
@ -185,7 +185,7 @@ class ZoneStorageCollection(collection.StorageCollection):
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up configured zones as well as Home Assistant zone if necessary."""
|
||||
component = entity_component.EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
component = entity_component.EntityComponent[Zone](_LOGGER, DOMAIN, hass)
|
||||
id_manager = collection.IDManager()
|
||||
|
||||
yaml_collection = collection.IDLessCollection(
|
||||
|
|
|
@ -7,7 +7,7 @@ from datetime import timedelta
|
|||
from itertools import chain
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType
|
|||
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
|
||||
DATA_INSTANCES = "entity_components"
|
||||
|
||||
_EntityT = TypeVar("_EntityT", bound=entity.Entity)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
|
||||
|
@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
|
|||
await entity_obj.async_update_ha_state(True)
|
||||
|
||||
|
||||
class EntityComponent:
|
||||
class EntityComponent(Generic[_EntityT]):
|
||||
"""The EntityComponent manages platforms that manages entities.
|
||||
|
||||
This class has the following responsibilities:
|
||||
|
@ -86,18 +88,19 @@ class EntityComponent:
|
|||
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
|
||||
|
||||
@property
|
||||
def entities(self) -> Iterable[entity.Entity]:
|
||||
def entities(self) -> Iterable[_EntityT]:
|
||||
"""Return an iterable that returns all entities."""
|
||||
return chain.from_iterable(
|
||||
platform.entities.values() for platform in self._platforms.values()
|
||||
platform.entities.values() # type: ignore[misc]
|
||||
for platform in self._platforms.values()
|
||||
)
|
||||
|
||||
def get_entity(self, entity_id: str) -> entity.Entity | None:
|
||||
def get_entity(self, entity_id: str) -> _EntityT | None:
|
||||
"""Get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity_obj = platform.entities.get(entity_id)
|
||||
if entity_obj is not None:
|
||||
return entity_obj
|
||||
return entity_obj # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def setup(self, config: ConfigType) -> None:
|
||||
|
@ -176,14 +179,14 @@ class EntityComponent:
|
|||
|
||||
async def async_extract_from_service(
|
||||
self, service_call: ServiceCall, expand_group: bool = True
|
||||
) -> list[entity.Entity]:
|
||||
) -> list[_EntityT]:
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
||||
Will return an empty list if entities specified but unknown.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return await service.async_extract_entities(
|
||||
return await service.async_extract_entities( # type: ignore[return-value]
|
||||
self.hass, self.entities, service_call, expand_group
|
||||
)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from homeassistant.loader import async_get_integration
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import config_per_platform
|
||||
from .entity import Entity
|
||||
from .entity_component import EntityComponent
|
||||
from .entity_platform import EntityPlatform, async_get_platforms
|
||||
from .service import async_register_admin_service
|
||||
|
@ -120,7 +121,7 @@ async def _async_setup_platform(
|
|||
)
|
||||
return
|
||||
|
||||
entity_component: EntityComponent = hass.data[integration_platform]
|
||||
entity_component: EntityComponent[Entity] = hass.data[integration_platform]
|
||||
tasks = [
|
||||
entity_component.async_setup_platform(integration_name, p_config)
|
||||
for p_config in platform_configs
|
||||
|
|
Loading…
Add table
Reference in a new issue