From 996bcbdac6f23b0e72374411489e4daf27fab189 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 14 Sep 2022 20:16:23 +0200 Subject: [PATCH] Make EntityComponent generic (#78473) --- .../components/automation/__init__.py | 28 ++++++++++--------- homeassistant/components/counter/__init__.py | 2 +- homeassistant/components/dominos/__init__.py | 4 +-- .../components/image_processing/__init__.py | 4 ++- .../components/input_boolean/__init__.py | 2 +- .../components/input_button/__init__.py | 2 +- .../components/input_datetime/__init__.py | 2 +- .../components/input_number/__init__.py | 2 +- .../components/input_select/__init__.py | 2 +- .../components/input_text/__init__.py | 2 +- homeassistant/components/mailbox/__init__.py | 2 +- homeassistant/components/person/__init__.py | 2 +- homeassistant/components/plant/__init__.py | 2 +- .../components/remember_the_milk/__init__.py | 2 +- homeassistant/components/rest/__init__.py | 3 +- homeassistant/components/schedule/__init__.py | 2 +- homeassistant/components/script/__init__.py | 14 +++------- homeassistant/components/timer/__init__.py | 2 +- homeassistant/components/zone/__init__.py | 2 +- homeassistant/helpers/entity_component.py | 19 +++++++------ homeassistant/helpers/reload.py | 3 +- 21 files changed, 53 insertions(+), 50 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 454edce5cac..b1ec0c68e4b 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -154,12 +154,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if entity_id in cast(AutomationEntity, automation_entity).referenced_entities + if entity_id in automation_entity.referenced_entities ] @@ -169,12 +169,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(cast(AutomationEntity, automation_entity).referenced_entities) + return list(automation_entity.referenced_entities) @callback @@ -183,12 +183,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if device_id in cast(AutomationEntity, automation_entity).referenced_devices + if device_id in automation_entity.referenced_devices ] @@ -198,12 +198,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(cast(AutomationEntity, automation_entity).referenced_devices) + return list(automation_entity.referenced_devices) @callback @@ -212,12 +212,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] return [ automation_entity.entity_id for automation_entity in component.entities - if area_id in cast(AutomationEntity, automation_entity).referenced_areas + if area_id in automation_entity.referenced_areas ] @@ -227,12 +227,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] if (automation_entity := component.get_entity(entity_id)) is None: return [] - return list(cast(AutomationEntity, automation_entity).referenced_areas) + return list(automation_entity.referenced_areas) @callback @@ -252,7 +252,9 @@ def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up all automations.""" - hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) + hass.data[DOMAIN] = component = EntityComponent[AutomationEntity]( + LOGGER, DOMAIN, hass + ) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 113826c2291..dedeb428c0c 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -93,7 +93,7 @@ CONFIG_SCHEMA = vol.Schema( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the counters.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[Counter](_LOGGER, DOMAIN, hass) id_manager = collection.IDManager() yaml_collection = collection.YamlCollection( diff --git a/homeassistant/components/dominos/__init__.py b/homeassistant/components/dominos/__init__.py index 31feaa7687e..ecfe7b65a7d 100644 --- a/homeassistant/components/dominos/__init__.py +++ b/homeassistant/components/dominos/__init__.py @@ -67,9 +67,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up is called when Home Assistant is loading our component.""" dominos = Dominos(hass, config) - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[DominosOrder](_LOGGER, DOMAIN, hass) hass.data[DOMAIN] = {} - entities = [] + entities: list[DominosOrder] = [] conf = config[DOMAIN] hass.services.register( diff --git a/homeassistant/components/image_processing/__init__.py b/homeassistant/components/image_processing/__init__.py index 26e6d195b92..8987a366aee 100644 --- a/homeassistant/components/image_processing/__init__.py +++ b/homeassistant/components/image_processing/__init__.py @@ -85,7 +85,9 @@ class FaceInformation(TypedDict, total=False): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the image processing.""" - component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL) + component = EntityComponent[ImageProcessingEntity]( + _LOGGER, DOMAIN, hass, SCAN_INTERVAL + ) await component.async_setup(config) diff --git a/homeassistant/components/input_boolean/__init__.py b/homeassistant/components/input_boolean/__init__.py index f1cdb145a7c..d1d19247121 100644 --- a/homeassistant/components/input_boolean/__init__.py +++ b/homeassistant/components/input_boolean/__init__.py @@ -92,7 +92,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input boolean.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputBoolean](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/input_button/__init__.py b/homeassistant/components/input_button/__init__.py index 14ff940ff64..f425c8e3da2 100644 --- a/homeassistant/components/input_button/__init__.py +++ b/homeassistant/components/input_button/__init__.py @@ -77,7 +77,7 @@ class InputButtonStorageCollection(collection.StorageCollection): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input button.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputButton](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/input_datetime/__init__.py b/homeassistant/components/input_datetime/__init__.py index afd94ea60f4..5913789d53f 100644 --- a/homeassistant/components/input_datetime/__init__.py +++ b/homeassistant/components/input_datetime/__init__.py @@ -130,7 +130,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({}) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input datetime.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputDatetime](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index 3a7f7b29f13..99e54dc9baa 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -107,7 +107,7 @@ STORAGE_VERSION = 1 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input slider.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputNumber](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/input_select/__init__.py b/homeassistant/components/input_select/__init__.py index 41b079f0888..f30b2ca1e36 100644 --- a/homeassistant/components/input_select/__init__.py +++ b/homeassistant/components/input_select/__init__.py @@ -132,7 +132,7 @@ class InputSelectStore(Store): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input select.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputSelect](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/input_text/__init__.py b/homeassistant/components/input_text/__init__.py index 072f17c72a3..6069ae8143a 100644 --- a/homeassistant/components/input_text/__init__.py +++ b/homeassistant/components/input_text/__init__.py @@ -107,7 +107,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({}) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input text.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[InputText](_LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/mailbox/__init__.py b/homeassistant/components/mailbox/__init__.py index 4e65d989b98..f97b2c5337b 100644 --- a/homeassistant/components/mailbox/__init__.py +++ b/homeassistant/components/mailbox/__init__.py @@ -83,7 +83,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: mailboxes.append(mailbox) mailbox_entity = MailboxEntity(mailbox) - component = EntityComponent( + component = EntityComponent[MailboxEntity]( logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL ) await component.async_add_entities([mailbox_entity]) diff --git a/homeassistant/components/person/__init__.py b/homeassistant/components/person/__init__.py index c41be68d6ea..0823e9e4b55 100644 --- a/homeassistant/components/person/__init__.py +++ b/homeassistant/components/person/__init__.py @@ -326,7 +326,7 @@ The following persons point at invalid users: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the person component.""" - entity_component = EntityComponent(_LOGGER, DOMAIN, hass) + entity_component = EntityComponent[Person](_LOGGER, DOMAIN, hass) id_manager = collection.IDManager() yaml_collection = collection.YamlCollection( logging.getLogger(f"{__name__}.yaml_collection"), id_manager diff --git a/homeassistant/components/plant/__init__.py b/homeassistant/components/plant/__init__.py index 0d95ccbc300..69f440b6859 100644 --- a/homeassistant/components/plant/__init__.py +++ b/homeassistant/components/plant/__init__.py @@ -111,7 +111,7 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: {cv.string: PLANT_SCHEMA}}, extra=vol.ALLOW_ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Plant component.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[Plant](_LOGGER, DOMAIN, hass) entities = [] for plant_name, plant_config in config[DOMAIN].items(): diff --git a/homeassistant/components/remember_the_milk/__init__.py b/homeassistant/components/remember_the_milk/__init__.py index fbc2518ce1b..3331f9c61d4 100644 --- a/homeassistant/components/remember_the_milk/__init__.py +++ b/homeassistant/components/remember_the_milk/__init__.py @@ -52,7 +52,7 @@ SERVICE_SCHEMA_COMPLETE_TASK = vol.Schema({vol.Required(CONF_ID): cv.string}) def setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Remember the milk component.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[RememberTheMilk](_LOGGER, DOMAIN, hass) stored_rtm_config = RememberTheMilkConfiguration(hass) for rtm_config in config[DOMAIN]: diff --git a/homeassistant/components/rest/__init__.py b/homeassistant/components/rest/__init__.py index f8e6941572a..282f05aada8 100644 --- a/homeassistant/components/rest/__init__.py +++ b/homeassistant/components/rest/__init__.py @@ -27,6 +27,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers import discovery, template +from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import ( DEFAULT_SCAN_INTERVAL, EntityComponent, @@ -53,7 +54,7 @@ COORDINATOR_AWARE_PLATFORMS = [SENSOR_DOMAIN, BINARY_SENSOR_DOMAIN] async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the rest platforms.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[Entity](_LOGGER, DOMAIN, hass) _async_setup_shared_data(hass) async def reload_service_handler(service: ServiceCall) -> None: diff --git a/homeassistant/components/schedule/__init__.py b/homeassistant/components/schedule/__init__.py index 394e2ae3c36..fefb5189e3c 100644 --- a/homeassistant/components/schedule/__init__.py +++ b/homeassistant/components/schedule/__init__.py @@ -154,7 +154,7 @@ ENTITY_SCHEMA = vol.Schema( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input select.""" - component = EntityComponent(LOGGER, DOMAIN, hass) + component = EntityComponent[Schedule](LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 53bd256c624..a5ea2a17e0b 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -183,7 +183,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Load the scripts from the configuration.""" - hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) + hass.data[DOMAIN] = component = EntityComponent[ScriptEntity](LOGGER, DOMAIN, hass) # Process integration platforms right away since # we will create entities before firing EVENT_COMPONENT_LOADED @@ -205,9 +205,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def turn_on_service(service: ServiceCall) -> None: """Call a service to turn script on.""" variables = service.data.get(ATTR_VARIABLES) - script_entities: list[ScriptEntity] = cast( - list[ScriptEntity], await component.async_extract_from_service(service) - ) + script_entities = await component.async_extract_from_service(service) for script_entity in script_entities: await script_entity.async_turn_on( variables=variables, context=service.context, wait=False @@ -216,9 +214,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def turn_off_service(service: ServiceCall) -> None: """Cancel a script.""" # Stopping a script is ok to be done in parallel - script_entities: list[ScriptEntity] = cast( - list[ScriptEntity], await component.async_extract_from_service(service) - ) + script_entities = await component.async_extract_from_service(service) if not script_entities: return @@ -232,9 +228,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def toggle_service(service: ServiceCall) -> None: """Toggle a script.""" - script_entities: list[ScriptEntity] = cast( - list[ScriptEntity], await component.async_extract_from_service(service) - ) + script_entities = await component.async_extract_from_service(service) for script_entity in script_entities: await script_entity.async_toggle(context=service.context, wait=False) diff --git a/homeassistant/components/timer/__init__.py b/homeassistant/components/timer/__init__.py index ff50e96a18c..6b141ccea4c 100644 --- a/homeassistant/components/timer/__init__.py +++ b/homeassistant/components/timer/__init__.py @@ -106,7 +106,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({}) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an input select.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + component = EntityComponent[Timer](_LOGGER, DOMAIN, hass) id_manager = collection.IDManager() yaml_collection = collection.YamlCollection( diff --git a/homeassistant/components/zone/__init__.py b/homeassistant/components/zone/__init__.py index aa910a7789e..816d8a62f18 100644 --- a/homeassistant/components/zone/__init__.py +++ b/homeassistant/components/zone/__init__.py @@ -185,7 +185,7 @@ class ZoneStorageCollection(collection.StorageCollection): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up configured zones as well as Home Assistant zone if necessary.""" - component = entity_component.EntityComponent(_LOGGER, DOMAIN, hass) + component = entity_component.EntityComponent[Zone](_LOGGER, DOMAIN, hass) id_manager = collection.IDManager() yaml_collection = collection.IDLessCollection( diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 1cef123b292..fb627820060 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -7,7 +7,7 @@ from datetime import timedelta from itertools import chain import logging from types import ModuleType -from typing import Any +from typing import Any, Generic, TypeVar import voluptuous as vol @@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType DEFAULT_SCAN_INTERVAL = timedelta(seconds=15) DATA_INSTANCES = "entity_components" +_EntityT = TypeVar("_EntityT", bound=entity.Entity) + @bind_hass async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None: @@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None: await entity_obj.async_update_ha_state(True) -class EntityComponent: +class EntityComponent(Generic[_EntityT]): """The EntityComponent manages platforms that manages entities. This class has the following responsibilities: @@ -86,18 +88,19 @@ class EntityComponent: hass.data.setdefault(DATA_INSTANCES, {})[domain] = self @property - def entities(self) -> Iterable[entity.Entity]: + def entities(self) -> Iterable[_EntityT]: """Return an iterable that returns all entities.""" return chain.from_iterable( - platform.entities.values() for platform in self._platforms.values() + platform.entities.values() # type: ignore[misc] + for platform in self._platforms.values() ) - def get_entity(self, entity_id: str) -> entity.Entity | None: + def get_entity(self, entity_id: str) -> _EntityT | None: """Get an entity.""" for platform in self._platforms.values(): entity_obj = platform.entities.get(entity_id) if entity_obj is not None: - return entity_obj + return entity_obj # type: ignore[return-value] return None def setup(self, config: ConfigType) -> None: @@ -176,14 +179,14 @@ class EntityComponent: async def async_extract_from_service( self, service_call: ServiceCall, expand_group: bool = True - ) -> list[entity.Entity]: + ) -> list[_EntityT]: """Extract all known and available entities from a service call. Will return an empty list if entities specified but unknown. This method must be run in the event loop. """ - return await service.async_extract_entities( + return await service.async_extract_entities( # type: ignore[return-value] self.hass, self.entities, service_call, expand_group ) diff --git a/homeassistant/helpers/reload.py b/homeassistant/helpers/reload.py index 83698557eb6..75529476dd2 100644 --- a/homeassistant/helpers/reload.py +++ b/homeassistant/helpers/reload.py @@ -14,6 +14,7 @@ from homeassistant.loader import async_get_integration from homeassistant.setup import async_setup_component from . import config_per_platform +from .entity import Entity from .entity_component import EntityComponent from .entity_platform import EntityPlatform, async_get_platforms from .service import async_register_admin_service @@ -120,7 +121,7 @@ async def _async_setup_platform( ) return - entity_component: EntityComponent = hass.data[integration_platform] + entity_component: EntityComponent[Entity] = hass.data[integration_platform] tasks = [ entity_component.async_setup_platform(integration_name, p_config) for p_config in platform_configs