Add a service to reload config entries that can easily be called though automations (#46762)
This commit is contained in:
parent
6fb0e49335
commit
08db262972
5 changed files with 228 additions and 72 deletions
|
@ -21,15 +21,30 @@ from homeassistant.const import (
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized, UnknownUser
|
from homeassistant.exceptions import HomeAssistantError, Unauthorized, UnknownUser
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.service import async_extract_referenced_entity_ids
|
from homeassistant.helpers.service import (
|
||||||
|
async_extract_config_entry_ids,
|
||||||
|
async_extract_referenced_entity_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTR_ENTRY_ID = "entry_id"
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DOMAIN = ha.DOMAIN
|
DOMAIN = ha.DOMAIN
|
||||||
SERVICE_RELOAD_CORE_CONFIG = "reload_core_config"
|
SERVICE_RELOAD_CORE_CONFIG = "reload_core_config"
|
||||||
|
SERVICE_RELOAD_CONFIG_ENTRY = "reload_config_entry"
|
||||||
SERVICE_CHECK_CONFIG = "check_config"
|
SERVICE_CHECK_CONFIG = "check_config"
|
||||||
SERVICE_UPDATE_ENTITY = "update_entity"
|
SERVICE_UPDATE_ENTITY = "update_entity"
|
||||||
SERVICE_SET_LOCATION = "set_location"
|
SERVICE_SET_LOCATION = "set_location"
|
||||||
SCHEMA_UPDATE_ENTITY = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
|
SCHEMA_UPDATE_ENTITY = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
|
||||||
|
SCHEMA_RELOAD_CONFIG_ENTRY = vol.All(
|
||||||
|
vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional(ATTR_ENTRY_ID): str,
|
||||||
|
**cv.ENTITY_SERVICE_FIELDS,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
cv.has_at_least_one_key(ATTR_ENTRY_ID, *cv.ENTITY_SERVICE_FIELDS),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
|
async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
|
||||||
|
@ -203,4 +218,26 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
|
||||||
vol.Schema({ATTR_LATITUDE: cv.latitude, ATTR_LONGITUDE: cv.longitude}),
|
vol.Schema({ATTR_LATITUDE: cv.latitude, ATTR_LONGITUDE: cv.longitude}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_handle_reload_config_entry(call):
|
||||||
|
"""Service handler for reloading a config entry."""
|
||||||
|
reload_entries = set()
|
||||||
|
if ATTR_ENTRY_ID in call.data:
|
||||||
|
reload_entries.add(call.data[ATTR_ENTRY_ID])
|
||||||
|
reload_entries.update(await async_extract_config_entry_ids(hass, call))
|
||||||
|
if not reload_entries:
|
||||||
|
raise ValueError("There were no matching config entries to reload")
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
hass.config_entries.async_reload(config_entry_id)
|
||||||
|
for config_entry_id in reload_entries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.helpers.service.async_register_admin_service(
|
||||||
|
ha.DOMAIN,
|
||||||
|
SERVICE_RELOAD_CONFIG_ENTRY,
|
||||||
|
async_handle_reload_config_entry,
|
||||||
|
schema=SCHEMA_RELOAD_CONFIG_ENTRY,
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -58,3 +58,19 @@ update_entity:
|
||||||
description: Force one or more entities to update its data
|
description: Force one or more entities to update its data
|
||||||
target:
|
target:
|
||||||
entity: {}
|
entity: {}
|
||||||
|
|
||||||
|
reload_config_entry:
|
||||||
|
name: Reload config entry
|
||||||
|
description: Reload a config entry that matches a target.
|
||||||
|
target:
|
||||||
|
entity: {}
|
||||||
|
device: {}
|
||||||
|
fields:
|
||||||
|
entry_id:
|
||||||
|
advanced: true
|
||||||
|
name: Config entry id
|
||||||
|
description: A configuration entry id
|
||||||
|
required: false
|
||||||
|
example: 8955375327824e14ba89e4b29cc3ec9a
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
|
|
@ -11,9 +11,9 @@ from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Tuple,
|
Optional,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
cast,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -78,6 +78,29 @@ class ServiceParams(TypedDict):
|
||||||
target: dict | None
|
target: dict | None
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceTargetSelector:
|
||||||
|
"""Class to hold a target selector for a service."""
|
||||||
|
|
||||||
|
def __init__(self, service_call: ha.ServiceCall):
|
||||||
|
"""Extract ids from service call data."""
|
||||||
|
entity_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_ENTITY_ID)
|
||||||
|
device_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_DEVICE_ID)
|
||||||
|
area_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_AREA_ID)
|
||||||
|
|
||||||
|
self.entity_ids = (
|
||||||
|
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
|
||||||
|
)
|
||||||
|
self.device_ids = (
|
||||||
|
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
|
||||||
|
)
|
||||||
|
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_any_selector(self) -> bool:
|
||||||
|
"""Determine if any selectors are present."""
|
||||||
|
return bool(self.entity_ids or self.device_ids or self.area_ids)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SelectedEntities:
|
class SelectedEntities:
|
||||||
"""Class to hold the selected entities."""
|
"""Class to hold the selected entities."""
|
||||||
|
@ -93,6 +116,9 @@ class SelectedEntities:
|
||||||
missing_devices: set[str] = dataclasses.field(default_factory=set)
|
missing_devices: set[str] = dataclasses.field(default_factory=set)
|
||||||
missing_areas: set[str] = dataclasses.field(default_factory=set)
|
missing_areas: set[str] = dataclasses.field(default_factory=set)
|
||||||
|
|
||||||
|
# Referenced devices
|
||||||
|
referenced_devices: set[str] = dataclasses.field(default_factory=set)
|
||||||
|
|
||||||
def log_missing(self, missing_entities: set[str]) -> None:
|
def log_missing(self, missing_entities: set[str]) -> None:
|
||||||
"""Log about missing items."""
|
"""Log about missing items."""
|
||||||
parts = []
|
parts = []
|
||||||
|
@ -293,98 +319,88 @@ async def async_extract_entity_ids(
|
||||||
return referenced.referenced | referenced.indirectly_referenced
|
return referenced.referenced | referenced.indirectly_referenced
|
||||||
|
|
||||||
|
|
||||||
|
def _has_match(ids: Optional[Union[str, list]]) -> bool:
|
||||||
|
"""Check if ids can match anything."""
|
||||||
|
return ids not in (None, ENTITY_MATCH_NONE)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_extract_referenced_entity_ids(
|
async def async_extract_referenced_entity_ids(
|
||||||
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||||
) -> SelectedEntities:
|
) -> SelectedEntities:
|
||||||
"""Extract referenced entity IDs from a service call."""
|
"""Extract referenced entity IDs from a service call."""
|
||||||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
selector = ServiceTargetSelector(service_call)
|
||||||
device_ids = service_call.data.get(ATTR_DEVICE_ID)
|
|
||||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
|
||||||
|
|
||||||
selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE)
|
|
||||||
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
|
|
||||||
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
|
|
||||||
|
|
||||||
selected = SelectedEntities()
|
selected = SelectedEntities()
|
||||||
|
|
||||||
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
|
if not selector.has_any_selector:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
if selects_entity_ids:
|
entity_ids = selector.entity_ids
|
||||||
assert entity_ids is not None
|
|
||||||
|
|
||||||
# Entity ID attr can be a list or a string
|
|
||||||
if isinstance(entity_ids, str):
|
|
||||||
entity_ids = [entity_ids]
|
|
||||||
|
|
||||||
if expand_group:
|
if expand_group:
|
||||||
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
||||||
|
|
||||||
selected.referenced.update(entity_ids)
|
selected.referenced.update(entity_ids)
|
||||||
|
|
||||||
if not selects_device_ids and not selects_area_ids:
|
if not selector.device_ids and not selector.area_ids:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
area_reg, dev_reg, ent_reg = cast(
|
ent_reg = entity_registry.async_get(hass)
|
||||||
Tuple[
|
dev_reg = device_registry.async_get(hass)
|
||||||
area_registry.AreaRegistry,
|
area_reg = area_registry.async_get(hass)
|
||||||
device_registry.DeviceRegistry,
|
|
||||||
entity_registry.EntityRegistry,
|
|
||||||
],
|
|
||||||
await asyncio.gather(
|
|
||||||
area_registry.async_get_registry(hass),
|
|
||||||
device_registry.async_get_registry(hass),
|
|
||||||
entity_registry.async_get_registry(hass),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
picked_devices = set()
|
for device_id in selector.device_ids:
|
||||||
|
|
||||||
if selects_device_ids:
|
|
||||||
if isinstance(device_ids, str):
|
|
||||||
picked_devices = {device_ids}
|
|
||||||
else:
|
|
||||||
assert isinstance(device_ids, list)
|
|
||||||
picked_devices = set(device_ids)
|
|
||||||
|
|
||||||
for device_id in picked_devices:
|
|
||||||
if device_id not in dev_reg.devices:
|
if device_id not in dev_reg.devices:
|
||||||
selected.missing_devices.add(device_id)
|
selected.missing_devices.add(device_id)
|
||||||
|
|
||||||
if selects_area_ids:
|
for area_id in selector.area_ids:
|
||||||
assert area_ids is not None
|
|
||||||
|
|
||||||
if isinstance(area_ids, str):
|
|
||||||
area_lookup = {area_ids}
|
|
||||||
else:
|
|
||||||
area_lookup = set(area_ids)
|
|
||||||
|
|
||||||
for area_id in area_lookup:
|
|
||||||
if area_id not in area_reg.areas:
|
if area_id not in area_reg.areas:
|
||||||
selected.missing_areas.add(area_id)
|
selected.missing_areas.add(area_id)
|
||||||
continue
|
|
||||||
|
|
||||||
# Find entities tied to an area
|
|
||||||
for entity_entry in ent_reg.entities.values():
|
|
||||||
if entity_entry.area_id in area_lookup:
|
|
||||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
|
||||||
|
|
||||||
# Find devices for this area
|
# Find devices for this area
|
||||||
|
selected.referenced_devices.update(selector.device_ids)
|
||||||
for device_entry in dev_reg.devices.values():
|
for device_entry in dev_reg.devices.values():
|
||||||
if device_entry.area_id in area_lookup:
|
if device_entry.area_id in selector.area_ids:
|
||||||
picked_devices.add(device_entry.id)
|
selected.referenced_devices.add(device_entry.id)
|
||||||
|
|
||||||
if not picked_devices:
|
if not selector.area_ids and not selected.referenced_devices:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
for entity_entry in ent_reg.entities.values():
|
for ent_entry in ent_reg.entities.values():
|
||||||
if not entity_entry.area_id and entity_entry.device_id in picked_devices:
|
if ent_entry.area_id in selector.area_ids or (
|
||||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
not ent_entry.area_id and ent_entry.device_id in selected.referenced_devices
|
||||||
|
):
|
||||||
|
selected.indirectly_referenced.add(ent_entry.entity_id)
|
||||||
|
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass
|
||||||
|
async def async_extract_config_entry_ids(
|
||||||
|
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||||
|
) -> set:
|
||||||
|
"""Extract referenced config entry ids from a service call."""
|
||||||
|
referenced = await async_extract_referenced_entity_ids(
|
||||||
|
hass, service_call, expand_group
|
||||||
|
)
|
||||||
|
ent_reg = entity_registry.async_get(hass)
|
||||||
|
dev_reg = device_registry.async_get(hass)
|
||||||
|
config_entry_ids: set[str] = set()
|
||||||
|
|
||||||
|
# Some devices may have no entities
|
||||||
|
for device_id in referenced.referenced_devices:
|
||||||
|
if device_id in dev_reg.devices:
|
||||||
|
device = dev_reg.async_get(device_id)
|
||||||
|
if device is not None:
|
||||||
|
config_entry_ids.update(device.config_entries)
|
||||||
|
|
||||||
|
for entity_id in referenced.referenced | referenced.indirectly_referenced:
|
||||||
|
entry = ent_reg.async_get(entity_id)
|
||||||
|
if entry is not None and entry.config_entry_id is not None:
|
||||||
|
config_entry_ids.add(entry.config_entry_id)
|
||||||
|
|
||||||
|
return config_entry_ids
|
||||||
|
|
||||||
|
|
||||||
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
|
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
|
||||||
"""Load services file for an integration."""
|
"""Load services file for an integration."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -11,6 +11,7 @@ import yaml
|
||||||
from homeassistant import config
|
from homeassistant import config
|
||||||
import homeassistant.components as comps
|
import homeassistant.components as comps
|
||||||
from homeassistant.components.homeassistant import (
|
from homeassistant.components.homeassistant import (
|
||||||
|
ATTR_ENTRY_ID,
|
||||||
SERVICE_CHECK_CONFIG,
|
SERVICE_CHECK_CONFIG,
|
||||||
SERVICE_RELOAD_CORE_CONFIG,
|
SERVICE_RELOAD_CORE_CONFIG,
|
||||||
SERVICE_SET_LOCATION,
|
SERVICE_SET_LOCATION,
|
||||||
|
@ -34,9 +35,11 @@ from homeassistant.helpers import entity
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
|
MockConfigEntry,
|
||||||
async_capture_events,
|
async_capture_events,
|
||||||
async_mock_service,
|
async_mock_service,
|
||||||
get_test_home_assistant,
|
get_test_home_assistant,
|
||||||
|
mock_registry,
|
||||||
mock_service,
|
mock_service,
|
||||||
patch_yaml_files,
|
patch_yaml_files,
|
||||||
)
|
)
|
||||||
|
@ -385,3 +388,62 @@ async def test_not_allowing_recursion(hass, caplog):
|
||||||
f"Called service homeassistant.{service} with invalid entities homeassistant.light"
|
f"Called service homeassistant.{service} with invalid entities homeassistant.light"
|
||||||
in caplog.text
|
in caplog.text
|
||||||
), service
|
), service
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_config_entry_by_entity_id(hass):
|
||||||
|
"""Test being able to reload a config entry by entity_id."""
|
||||||
|
await async_setup_component(hass, "homeassistant", {})
|
||||||
|
entity_reg = mock_registry(hass)
|
||||||
|
entry1 = MockConfigEntry(domain="mockdomain")
|
||||||
|
entry1.add_to_hass(hass)
|
||||||
|
entry2 = MockConfigEntry(domain="mockdomain")
|
||||||
|
entry2.add_to_hass(hass)
|
||||||
|
reg_entity1 = entity_reg.async_get_or_create(
|
||||||
|
"binary_sensor", "powerwall", "battery_charging", config_entry=entry1
|
||||||
|
)
|
||||||
|
reg_entity2 = entity_reg.async_get_or_create(
|
||||||
|
"binary_sensor", "powerwall", "battery_status", config_entry=entry2
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||||
|
return_value=None,
|
||||||
|
) as mock_reload:
|
||||||
|
await hass.services.async_call(
|
||||||
|
"homeassistant",
|
||||||
|
"reload_config_entry",
|
||||||
|
{"entity_id": f"{reg_entity1.entity_id},{reg_entity2.entity_id}"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mock_reload.mock_calls) == 2
|
||||||
|
assert {mock_reload.mock_calls[0][1][0], mock_reload.mock_calls[1][1][0]} == {
|
||||||
|
entry1.entry_id,
|
||||||
|
entry2.entry_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"homeassistant",
|
||||||
|
"reload_config_entry",
|
||||||
|
{"entity_id": "unknown.entity_id"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reload_config_entry_by_entry_id(hass):
|
||||||
|
"""Test being able to reload a config entry by config entry id."""
|
||||||
|
await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||||
|
return_value=None,
|
||||||
|
) as mock_reload:
|
||||||
|
await hass.services.async_call(
|
||||||
|
"homeassistant",
|
||||||
|
"reload_config_entry",
|
||||||
|
{ATTR_ENTRY_ID: "8955375327824e14ba89e4b29cc3ec9a"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mock_reload.mock_calls) == 1
|
||||||
|
assert mock_reload.mock_calls[0][1][0] == "8955375327824e14ba89e4b29cc3ec9a"
|
||||||
|
|
|
@ -1015,3 +1015,28 @@ async def test_async_extract_entities_warn_referenced(hass, caplog):
|
||||||
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
|
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
|
||||||
in caplog.text
|
in caplog.text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_extract_config_entry_ids(hass):
|
||||||
|
"""Test we can find devices that have no entities."""
|
||||||
|
|
||||||
|
device_no_entities = dev_reg.DeviceEntry(
|
||||||
|
id="device-no-entities", config_entries={"abc"}
|
||||||
|
)
|
||||||
|
|
||||||
|
call = ha.ServiceCall(
|
||||||
|
"homeassistant",
|
||||||
|
"reload_config_entry",
|
||||||
|
{
|
||||||
|
"device_id": "device-no-entities",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_device_registry(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
device_no_entities.id: device_no_entities,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue