Make EntityComponent generic (#78473)
This commit is contained in:
parent
fd05d949cc
commit
996bcbdac6
21 changed files with 53 additions and 50 deletions
|
@ -7,7 +7,7 @@ from datetime import timedelta
|
|||
from itertools import chain
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType
|
|||
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
|
||||
DATA_INSTANCES = "entity_components"
|
||||
|
||||
_EntityT = TypeVar("_EntityT", bound=entity.Entity)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
|
||||
|
@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
|
|||
await entity_obj.async_update_ha_state(True)
|
||||
|
||||
|
||||
class EntityComponent:
|
||||
class EntityComponent(Generic[_EntityT]):
|
||||
"""The EntityComponent manages platforms that manages entities.
|
||||
|
||||
This class has the following responsibilities:
|
||||
|
@ -86,18 +88,19 @@ class EntityComponent:
|
|||
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
|
||||
|
||||
@property
|
||||
def entities(self) -> Iterable[entity.Entity]:
|
||||
def entities(self) -> Iterable[_EntityT]:
|
||||
"""Return an iterable that returns all entities."""
|
||||
return chain.from_iterable(
|
||||
platform.entities.values() for platform in self._platforms.values()
|
||||
platform.entities.values() # type: ignore[misc]
|
||||
for platform in self._platforms.values()
|
||||
)
|
||||
|
||||
def get_entity(self, entity_id: str) -> entity.Entity | None:
|
||||
def get_entity(self, entity_id: str) -> _EntityT | None:
|
||||
"""Get an entity."""
|
||||
for platform in self._platforms.values():
|
||||
entity_obj = platform.entities.get(entity_id)
|
||||
if entity_obj is not None:
|
||||
return entity_obj
|
||||
return entity_obj # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def setup(self, config: ConfigType) -> None:
|
||||
|
@ -176,14 +179,14 @@ class EntityComponent:
|
|||
|
||||
async def async_extract_from_service(
|
||||
self, service_call: ServiceCall, expand_group: bool = True
|
||||
) -> list[entity.Entity]:
|
||||
) -> list[_EntityT]:
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
||||
Will return an empty list if entities specified but unknown.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return await service.async_extract_entities(
|
||||
return await service.async_extract_entities( # type: ignore[return-value]
|
||||
self.hass, self.entities, service_call, expand_group
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue