Make EntityComponent generic (#78473)

This commit is contained in:
epenet 2022-09-14 20:16:23 +02:00 committed by GitHub
parent fd05d949cc
commit 996bcbdac6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 53 additions and 50 deletions

View file

@ -7,7 +7,7 @@ from datetime import timedelta
from itertools import chain
import logging
from types import ModuleType
from typing import Any
from typing import Any, Generic, TypeVar
import voluptuous as vol
@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
DATA_INSTANCES = "entity_components"
_EntityT = TypeVar("_EntityT", bound=entity.Entity)
@bind_hass
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
await entity_obj.async_update_ha_state(True)
class EntityComponent:
class EntityComponent(Generic[_EntityT]):
"""The EntityComponent manages platforms that manages entities.
This class has the following responsibilities:
@ -86,18 +88,19 @@ class EntityComponent:
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property
def entities(self) -> Iterable[entity.Entity]:
def entities(self) -> Iterable[_EntityT]:
"""Return an iterable that returns all entities."""
return chain.from_iterable(
platform.entities.values() for platform in self._platforms.values()
platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
)
def get_entity(self, entity_id: str) -> entity.Entity | None:
def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj
return entity_obj # type: ignore[return-value]
return None
def setup(self, config: ConfigType) -> None:
@ -176,14 +179,14 @@ class EntityComponent:
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
) -> list[entity.Entity]:
) -> list[_EntityT]:
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
return await service.async_extract_entities(
return await service.async_extract_entities( # type: ignore[return-value]
self.hass, self.entities, service_call, expand_group
)