Add labels to service target (#113753)
This commit is contained in:
parent
541d4b78ac
commit
167e66d45c
4 changed files with 256 additions and 9 deletions
|
@ -19,6 +19,7 @@ from homeassistant.const import (
|
|||
ATTR_DEVICE_ID,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_FLOOR_ID,
|
||||
ATTR_LABEL_ID,
|
||||
CONF_ENTITY_ID,
|
||||
CONF_SERVICE,
|
||||
CONF_SERVICE_DATA,
|
||||
|
@ -55,6 +56,7 @@ from . import (
|
|||
device_registry,
|
||||
entity_registry,
|
||||
floor_registry,
|
||||
label_registry,
|
||||
template,
|
||||
translation,
|
||||
)
|
||||
|
@ -196,7 +198,7 @@ class ServiceParams(TypedDict):
|
|||
class ServiceTargetSelector:
|
||||
"""Class to hold a target selector for a service."""
|
||||
|
||||
__slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids")
|
||||
__slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids", "label_ids")
|
||||
|
||||
def __init__(self, service_call: ServiceCall) -> None:
|
||||
"""Extract ids from service call data."""
|
||||
|
@ -205,6 +207,7 @@ class ServiceTargetSelector:
|
|||
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
|
||||
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
|
||||
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
|
||||
label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID)
|
||||
|
||||
self.entity_ids = (
|
||||
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
|
||||
|
@ -216,12 +219,19 @@ class ServiceTargetSelector:
|
|||
self.floor_ids = (
|
||||
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
|
||||
)
|
||||
self.label_ids = (
|
||||
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
|
||||
)
|
||||
|
||||
@property
|
||||
def has_any_selector(self) -> bool:
|
||||
"""Determine if any selectors are present."""
|
||||
return bool(
|
||||
self.entity_ids or self.device_ids or self.area_ids or self.floor_ids
|
||||
self.entity_ids
|
||||
or self.device_ids
|
||||
or self.area_ids
|
||||
or self.floor_ids
|
||||
or self.label_ids
|
||||
)
|
||||
|
||||
|
||||
|
@ -232,7 +242,7 @@ class SelectedEntities:
|
|||
# Entities that were explicitly mentioned.
|
||||
referenced: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
# Entities that were referenced via device/area/floor ID.
|
||||
# Entities that were referenced via device/area/floor/label ID.
|
||||
# Should not trigger a warning when they don't exist.
|
||||
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
@ -240,6 +250,7 @@ class SelectedEntities:
|
|||
missing_devices: set[str] = dataclasses.field(default_factory=set)
|
||||
missing_areas: set[str] = dataclasses.field(default_factory=set)
|
||||
missing_floors: set[str] = dataclasses.field(default_factory=set)
|
||||
missing_labels: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
# Referenced devices
|
||||
referenced_devices: set[str] = dataclasses.field(default_factory=set)
|
||||
|
@ -253,6 +264,7 @@ class SelectedEntities:
|
|||
("areas", self.missing_areas),
|
||||
("devices", self.missing_devices),
|
||||
("entities", missing_entities),
|
||||
("labels", self.missing_labels),
|
||||
):
|
||||
if items:
|
||||
parts.append(f"{label} {', '.join(sorted(items))}")
|
||||
|
@ -467,7 +479,7 @@ def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
|
|||
|
||||
|
||||
@bind_hass
|
||||
def async_extract_referenced_entity_ids(
|
||||
def async_extract_referenced_entity_ids( # noqa: C901
|
||||
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
|
||||
) -> SelectedEntities:
|
||||
"""Extract referenced entity IDs from a service call."""
|
||||
|
@ -483,13 +495,19 @@ def async_extract_referenced_entity_ids(
|
|||
|
||||
selected.referenced.update(entity_ids)
|
||||
|
||||
if not selector.device_ids and not selector.area_ids and not selector.floor_ids:
|
||||
if (
|
||||
not selector.device_ids
|
||||
and not selector.area_ids
|
||||
and not selector.floor_ids
|
||||
and not selector.label_ids
|
||||
):
|
||||
return selected
|
||||
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
area_reg = area_registry.async_get(hass)
|
||||
floor_reg = floor_registry.async_get(hass)
|
||||
label_reg = label_registry.async_get(hass)
|
||||
|
||||
for floor_id in selector.floor_ids:
|
||||
if floor_id not in floor_reg.floors:
|
||||
|
@ -503,6 +521,28 @@ def async_extract_referenced_entity_ids(
|
|||
if device_id not in dev_reg.devices:
|
||||
selected.missing_devices.add(device_id)
|
||||
|
||||
for label_id in selector.label_ids:
|
||||
if label_id not in label_reg.labels:
|
||||
selected.missing_labels.add(label_id)
|
||||
|
||||
# Find areas, devices & entities for targeted labels
|
||||
if selector.label_ids:
|
||||
for area_entry in area_reg.areas.values():
|
||||
if area_entry.labels.intersection(selector.label_ids):
|
||||
selected.referenced_areas.add(area_entry.id)
|
||||
|
||||
for device_entry in dev_reg.devices.values():
|
||||
if device_entry.labels.intersection(selector.label_ids):
|
||||
selected.referenced_devices.add(device_entry.id)
|
||||
|
||||
for entity_entry in ent_reg.entities.values():
|
||||
if (
|
||||
entity_entry.entity_category is None
|
||||
and entity_entry.hidden_by is None
|
||||
and entity_entry.labels.intersection(selector.label_ids)
|
||||
):
|
||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||
|
||||
# Find areas for targeted floors
|
||||
if selector.floor_ids:
|
||||
for area_entry in area_reg.areas.values():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue