Make device automation type an enum (#62354)
This commit is contained in:
parent
2ddd45afd5
commit
334c6c5c02
12 changed files with 120 additions and 43 deletions
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
import logging
|
||||
from types import ModuleType
|
||||
|
@ -19,6 +20,7 @@ from homeassistant.helpers import (
|
|||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.helpers.frame import report
|
||||
from homeassistant.loader import IntegrationNotFound, bind_hass
|
||||
from homeassistant.requirements import async_get_integration_with_requirements
|
||||
|
||||
|
@ -45,32 +47,49 @@ class DeviceAutomationDetails(NamedTuple):
|
|||
get_capabilities_func: str
|
||||
|
||||
|
||||
TYPES = {
|
||||
"trigger": DeviceAutomationDetails(
|
||||
class DeviceAutomationType(Enum):
|
||||
"""Device automation type."""
|
||||
|
||||
TRIGGER = DeviceAutomationDetails(
|
||||
"device_trigger",
|
||||
"async_get_triggers",
|
||||
"async_get_trigger_capabilities",
|
||||
),
|
||||
"condition": DeviceAutomationDetails(
|
||||
)
|
||||
CONDITION = DeviceAutomationDetails(
|
||||
"device_condition",
|
||||
"async_get_conditions",
|
||||
"async_get_condition_capabilities",
|
||||
),
|
||||
"action": DeviceAutomationDetails(
|
||||
)
|
||||
ACTION = DeviceAutomationDetails(
|
||||
"device_action",
|
||||
"async_get_actions",
|
||||
"async_get_action_capabilities",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# TYPES is deprecated as of Home Assistant 2022.2, use DeviceAutomationType instead
|
||||
TYPES = {
|
||||
"trigger": DeviceAutomationType.TRIGGER.value,
|
||||
"condition": DeviceAutomationType.CONDITION.value,
|
||||
"action": DeviceAutomationType.ACTION.value,
|
||||
}
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_device_automations(
|
||||
hass: HomeAssistant,
|
||||
automation_type: str,
|
||||
automation_type: DeviceAutomationType | str,
|
||||
device_ids: Iterable[str] | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""Return all the device automations for a type optionally limited to specific device ids."""
|
||||
if isinstance(automation_type, str):
|
||||
report(
|
||||
"uses str for async_get_device_automations automation_type. This is "
|
||||
"deprecated and will stop working in Home Assistant 2022.4, it should be "
|
||||
"updated to use DeviceAutomationType instead",
|
||||
error_if_core=False,
|
||||
)
|
||||
automation_type = DeviceAutomationType[automation_type.upper()]
|
||||
return await _async_get_device_automations(hass, automation_type, device_ids)
|
||||
|
||||
|
||||
|
@ -98,13 +117,21 @@ async def async_setup(hass, config):
|
|||
|
||||
|
||||
async def async_get_device_automation_platform(
|
||||
hass: HomeAssistant, domain: str, automation_type: str
|
||||
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
|
||||
) -> ModuleType:
|
||||
"""Load device automation platform for integration.
|
||||
|
||||
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
|
||||
"""
|
||||
platform_name = TYPES[automation_type].section
|
||||
if isinstance(automation_type, str):
|
||||
report(
|
||||
"uses str for async_get_device_automation_platform automation_type. This "
|
||||
"is deprecated and will stop working in Home Assistant 2022.4, it should "
|
||||
"be updated to use DeviceAutomationType instead",
|
||||
error_if_core=False,
|
||||
)
|
||||
automation_type = DeviceAutomationType[automation_type.upper()]
|
||||
platform_name = automation_type.value.section
|
||||
try:
|
||||
integration = await async_get_integration_with_requirements(hass, domain)
|
||||
platform = integration.get_platform(platform_name)
|
||||
|
@ -114,7 +141,8 @@ async def async_get_device_automation_platform(
|
|||
) from err
|
||||
except ImportError as err:
|
||||
raise InvalidDeviceAutomationConfig(
|
||||
f"Integration '{domain}' does not support device automation {automation_type}s"
|
||||
f"Integration '{domain}' does not support device automation "
|
||||
f"{automation_type.name.lower()}s"
|
||||
) from err
|
||||
|
||||
return platform
|
||||
|
@ -131,7 +159,7 @@ async def _async_get_device_automations_from_domain(
|
|||
except InvalidDeviceAutomationConfig:
|
||||
return {}
|
||||
|
||||
function_name = TYPES[automation_type].get_automations_func
|
||||
function_name = automation_type.value.get_automations_func
|
||||
|
||||
return await asyncio.gather(
|
||||
*(
|
||||
|
@ -143,7 +171,9 @@ async def _async_get_device_automations_from_domain(
|
|||
|
||||
|
||||
async def _async_get_device_automations(
|
||||
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
|
||||
hass: HomeAssistant,
|
||||
automation_type: DeviceAutomationType,
|
||||
device_ids: Iterable[str] | None,
|
||||
) -> Mapping[str, list[dict[str, Any]]]:
|
||||
"""List device automations."""
|
||||
device_registry = dr.async_get(hass)
|
||||
|
@ -188,7 +218,7 @@ async def _async_get_device_automations(
|
|||
if isinstance(device_results, Exception):
|
||||
logging.getLogger(__name__).error(
|
||||
"Unexpected error fetching device %ss",
|
||||
automation_type,
|
||||
automation_type.name.lower(),
|
||||
exc_info=device_results,
|
||||
)
|
||||
continue
|
||||
|
@ -207,7 +237,9 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
|
|||
except InvalidDeviceAutomationConfig:
|
||||
return {}
|
||||
|
||||
function_name = TYPES[automation_type].get_capabilities_func
|
||||
if isinstance(automation_type, str): # until tests pass DeviceAutomationType
|
||||
automation_type = DeviceAutomationType[automation_type.upper()]
|
||||
function_name = automation_type.value.get_capabilities_func
|
||||
|
||||
if not hasattr(platform, function_name):
|
||||
# The device automation has no capabilities
|
||||
|
@ -256,9 +288,11 @@ def handle_device_errors(func):
|
|||
async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||
"""Handle request for device actions."""
|
||||
device_id = msg["device_id"]
|
||||
actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
|
||||
device_id
|
||||
)
|
||||
actions = (
|
||||
await _async_get_device_automations(
|
||||
hass, DeviceAutomationType.ACTION, [device_id]
|
||||
)
|
||||
).get(device_id)
|
||||
connection.send_result(msg["id"], actions)
|
||||
|
||||
|
||||
|
@ -274,7 +308,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
|
|||
"""Handle request for device conditions."""
|
||||
device_id = msg["device_id"]
|
||||
conditions = (
|
||||
await _async_get_device_automations(hass, "condition", [device_id])
|
||||
await _async_get_device_automations(
|
||||
hass, DeviceAutomationType.CONDITION, [device_id]
|
||||
)
|
||||
).get(device_id)
|
||||
connection.send_result(msg["id"], conditions)
|
||||
|
||||
|
@ -290,9 +326,11 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
|
|||
async def websocket_device_automation_list_triggers(hass, connection, msg):
|
||||
"""Handle request for device triggers."""
|
||||
device_id = msg["device_id"]
|
||||
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
|
||||
device_id
|
||||
)
|
||||
triggers = (
|
||||
await _async_get_device_automations(
|
||||
hass, DeviceAutomationType.TRIGGER, [device_id]
|
||||
)
|
||||
).get(device_id)
|
||||
connection.send_result(msg["id"], triggers)
|
||||
|
||||
|
||||
|
@ -308,7 +346,7 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
|
|||
"""Handle request for device action capabilities."""
|
||||
action = msg["action"]
|
||||
capabilities = await _async_get_device_automation_capabilities(
|
||||
hass, "action", action
|
||||
hass, DeviceAutomationType.ACTION, action
|
||||
)
|
||||
connection.send_result(msg["id"], capabilities)
|
||||
|
||||
|
@ -327,7 +365,7 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
|
|||
"""Handle request for device condition capabilities."""
|
||||
condition = msg["condition"]
|
||||
capabilities = await _async_get_device_automation_capabilities(
|
||||
hass, "condition", condition
|
||||
hass, DeviceAutomationType.CONDITION, condition
|
||||
)
|
||||
connection.send_result(msg["id"], capabilities)
|
||||
|
||||
|
@ -346,6 +384,6 @@ async def websocket_device_automation_get_trigger_capabilities(hass, connection,
|
|||
"""Handle request for device trigger capabilities."""
|
||||
trigger = msg["trigger"]
|
||||
capabilities = await _async_get_device_automation_capabilities(
|
||||
hass, "trigger", trigger
|
||||
hass, DeviceAutomationType.TRIGGER, trigger
|
||||
)
|
||||
connection.send_result(msg["id"], capabilities)
|
||||
|
|
|
@ -3,7 +3,11 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import CONF_DOMAIN
|
||||
|
||||
from . import DEVICE_TRIGGER_BASE_SCHEMA, async_get_device_automation_platform
|
||||
from . import (
|
||||
DEVICE_TRIGGER_BASE_SCHEMA,
|
||||
DeviceAutomationType,
|
||||
async_get_device_automation_platform,
|
||||
)
|
||||
from .exceptions import InvalidDeviceAutomationConfig
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
@ -14,7 +18,7 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
|||
async def async_validate_trigger_config(hass, config):
|
||||
"""Validate config."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], "trigger"
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||
)
|
||||
if not hasattr(platform, "async_validate_trigger_config"):
|
||||
return platform.TRIGGER_SCHEMA(config)
|
||||
|
@ -28,6 +32,6 @@ async def async_validate_trigger_config(hass, config):
|
|||
async def async_attach_trigger(hass, config, action, automation_info):
|
||||
"""Listen for trigger."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], "trigger"
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
|
||||
)
|
||||
return await platform.async_attach_trigger(hass, config, action, automation_info)
|
||||
|
|
|
@ -819,7 +819,9 @@ class HomeKit:
|
|||
valid_device_ids.append(device_id)
|
||||
for device_id, device_triggers in (
|
||||
await device_automation.async_get_device_automations(
|
||||
self.hass, "trigger", valid_device_ids
|
||||
self.hass,
|
||||
device_automation.DeviceAutomationType.TRIGGER,
|
||||
valid_device_ids,
|
||||
)
|
||||
).items():
|
||||
self.add_bridge_triggers_accessory(
|
||||
|
|
|
@ -512,7 +512,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
|
|||
|
||||
async def _async_get_supported_devices(hass):
|
||||
"""Return all supported devices."""
|
||||
results = await device_automation.async_get_device_automations(hass, "trigger")
|
||||
results = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.TRIGGER
|
||||
)
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
unsorted = {
|
||||
device_id: dev_reg.async_get(device_id).name or device_id
|
||||
|
|
|
@ -14,6 +14,7 @@ from typing import Any, Callable, cast
|
|||
|
||||
from homeassistant.components import zone as zone_cmp
|
||||
from homeassistant.components.device_automation import (
|
||||
DeviceAutomationType,
|
||||
async_get_device_automation_platform,
|
||||
)
|
||||
from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP
|
||||
|
@ -881,7 +882,7 @@ async def async_device_from_config(
|
|||
) -> ConditionCheckerType:
|
||||
"""Test a device condition."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], "condition"
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
return trace_condition_function(
|
||||
cast(
|
||||
|
@ -952,7 +953,7 @@ async def async_validate_condition_config(
|
|||
config = cv.DEVICE_CONDITION_SCHEMA(config)
|
||||
assert not isinstance(config, Template)
|
||||
platform = await async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], "condition"
|
||||
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
if hasattr(platform, "async_validate_condition_config"):
|
||||
return await platform.async_validate_condition_config(hass, config) # type: ignore
|
||||
|
|
|
@ -254,7 +254,7 @@ async def async_validate_action_config(
|
|||
|
||||
elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
|
||||
platform = await device_automation.async_get_device_automation_platform(
|
||||
hass, config[CONF_DOMAIN], "action"
|
||||
hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION
|
||||
)
|
||||
if hasattr(platform, "async_validate_action_config"):
|
||||
config = await platform.async_validate_action_config(hass, config) # type: ignore
|
||||
|
@ -590,7 +590,9 @@ class _ScriptRun:
|
|||
"""Perform the device automation specified in the action."""
|
||||
self._step_log("device automation")
|
||||
platform = await device_automation.async_get_device_automation_platform(
|
||||
self._hass, self._action[CONF_DOMAIN], "action"
|
||||
self._hass,
|
||||
self._action[CONF_DOMAIN],
|
||||
device_automation.DeviceAutomationType.ACTION,
|
||||
)
|
||||
await platform.async_call_action_from_config(
|
||||
self._hass, self._action, self._variables, self._context
|
||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
|
||||
from homeassistant.components import automation
|
||||
from homeassistant.components.NEW_DOMAIN import DOMAIN
|
||||
from homeassistant.components.device_automation import DeviceAutomationType
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry, entity_registry
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -56,7 +57,9 @@ async def test_get_actions(
|
|||
"entity_id": "NEW_DOMAIN.test_5678",
|
||||
},
|
||||
]
|
||||
actions = await async_get_device_automations(hass, "action", device_entry.id)
|
||||
actions = await async_get_device_automations(
|
||||
hass, DeviceAutomationType.ACTION, device_entry.id
|
||||
)
|
||||
assert_lists_same(actions, expected_actions)
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
|
||||
from homeassistant.components import automation
|
||||
from homeassistant.components.NEW_DOMAIN import DOMAIN
|
||||
from homeassistant.components.device_automation import DeviceAutomationType
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.helpers import device_registry, entity_registry
|
||||
|
@ -67,7 +68,9 @@ async def test_get_conditions(
|
|||
"entity_id": f"{DOMAIN}.test_5678",
|
||||
},
|
||||
]
|
||||
conditions = await async_get_device_automations(hass, "condition", device_entry.id)
|
||||
conditions = await async_get_device_automations(
|
||||
hass, DeviceAutomationType.CONDITION, device_entry.id
|
||||
)
|
||||
assert_lists_same(conditions, expected_conditions)
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
|
||||
from homeassistant.components import automation
|
||||
from homeassistant.components.NEW_DOMAIN import DOMAIN
|
||||
from homeassistant.components.device_automation import DeviceAutomationType
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.helpers import device_registry
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -60,7 +61,9 @@ async def test_get_triggers(hass, device_reg, entity_reg):
|
|||
"entity_id": f"{DOMAIN}.test_5678",
|
||||
},
|
||||
]
|
||||
triggers = await async_get_device_automations(hass, "trigger", device_entry.id)
|
||||
triggers = await async_get_device_automations(
|
||||
hass, DeviceAutomationType.TRIGGER, device_entry.id
|
||||
)
|
||||
assert_lists_same(triggers, expected_triggers)
|
||||
|
||||
|
||||
|
|
|
@ -69,7 +69,9 @@ CLIENT_REDIRECT_URI = "https://example.com/app/callback"
|
|||
|
||||
|
||||
async def async_get_device_automations(
|
||||
hass: HomeAssistant, automation_type: str, device_id: str
|
||||
hass: HomeAssistant,
|
||||
automation_type: device_automation.DeviceAutomationType | str,
|
||||
device_id: str,
|
||||
) -> Any:
|
||||
"""Get a device automation for a single device id."""
|
||||
automations = await device_automation.async_get_device_automations(
|
||||
|
|
|
@ -391,6 +391,13 @@ async def test_async_get_device_automations_single_device_trigger(
|
|||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.TRIGGER, [device_entry.id]
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
# Test deprecated str automation_type works, to be removed in 2022.4
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, "trigger", [device_entry.id]
|
||||
)
|
||||
|
@ -410,7 +417,9 @@ async def test_async_get_device_automations_all_devices_trigger(
|
|||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "trigger")
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.TRIGGER
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
|
@ -427,7 +436,9 @@ async def test_async_get_device_automations_all_devices_condition(
|
|||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "condition")
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.CONDITION
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 2
|
||||
|
||||
|
@ -444,7 +455,9 @@ async def test_async_get_device_automations_all_devices_action(
|
|||
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
|
||||
)
|
||||
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
|
||||
result = await device_automation.async_get_device_automations(hass, "action")
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.ACTION
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 3
|
||||
|
||||
|
@ -465,7 +478,9 @@ async def test_async_get_device_automations_all_devices_action_exception_throw(
|
|||
"homeassistant.components.light.device_trigger.async_get_triggers",
|
||||
side_effect=KeyError,
|
||||
):
|
||||
result = await device_automation.async_get_device_automations(hass, "trigger")
|
||||
result = await device_automation.async_get_device_automations(
|
||||
hass, device_automation.DeviceAutomationType.TRIGGER
|
||||
)
|
||||
assert device_entry.id in result
|
||||
assert len(result[device_entry.id]) == 0
|
||||
assert "KeyError" in caplog.text
|
||||
|
|
|
@ -16,7 +16,9 @@ async def test_get_actions(hass, push_registration):
|
|||
]
|
||||
|
||||
capabilitites = await device_automation._async_get_device_automation_capabilities(
|
||||
hass, "action", {"domain": DOMAIN, "device_id": device_id, "type": "notify"}
|
||||
hass,
|
||||
device_automation.DeviceAutomationType.ACTION,
|
||||
{"domain": DOMAIN, "device_id": device_id, "type": "notify"},
|
||||
)
|
||||
assert "extra_fields" in capabilitites
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue