Make device automation type an enum (#62354)

This commit is contained in:
Ville Skyttä 2021-12-20 20:16:30 +02:00 committed by GitHub
parent 2ddd45afd5
commit 334c6c5c02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 120 additions and 43 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Iterable, Mapping
from enum import Enum
from functools import wraps
import logging
from types import ModuleType
@ -19,6 +20,7 @@ from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.frame import report
from homeassistant.loader import IntegrationNotFound, bind_hass
from homeassistant.requirements import async_get_integration_with_requirements
@ -45,32 +47,49 @@ class DeviceAutomationDetails(NamedTuple):
get_capabilities_func: str
TYPES = {
"trigger": DeviceAutomationDetails(
class DeviceAutomationType(Enum):
"""Device automation type."""
TRIGGER = DeviceAutomationDetails(
"device_trigger",
"async_get_triggers",
"async_get_trigger_capabilities",
),
"condition": DeviceAutomationDetails(
)
CONDITION = DeviceAutomationDetails(
"device_condition",
"async_get_conditions",
"async_get_condition_capabilities",
),
"action": DeviceAutomationDetails(
)
ACTION = DeviceAutomationDetails(
"device_action",
"async_get_actions",
"async_get_action_capabilities",
),
)
# TYPES is deprecated as of Home Assistant 2022.2, use DeviceAutomationType instead
TYPES = {
"trigger": DeviceAutomationType.TRIGGER.value,
"condition": DeviceAutomationType.CONDITION.value,
"action": DeviceAutomationType.ACTION.value,
}
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
automation_type: str,
automation_type: DeviceAutomationType | str,
device_ids: Iterable[str] | None = None,
) -> Mapping[str, Any]:
"""Return all the device automations for a type optionally limited to specific device ids."""
if isinstance(automation_type, str):
report(
"uses str for async_get_device_automations automation_type. This is "
"deprecated and will stop working in Home Assistant 2022.4, it should be "
"updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
return await _async_get_device_automations(hass, automation_type, device_ids)
@ -98,13 +117,21 @@ async def async_setup(hass, config):
async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: str
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> ModuleType:
"""Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
"""
platform_name = TYPES[automation_type].section
if isinstance(automation_type, str):
report(
"uses str for async_get_device_automation_platform automation_type. This "
"is deprecated and will stop working in Home Assistant 2022.4, it should "
"be updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
platform_name = automation_type.value.section
try:
integration = await async_get_integration_with_requirements(hass, domain)
platform = integration.get_platform(platform_name)
@ -114,7 +141,8 @@ async def async_get_device_automation_platform(
) from err
except ImportError as err:
raise InvalidDeviceAutomationConfig(
f"Integration '{domain}' does not support device automation {automation_type}s"
f"Integration '{domain}' does not support device automation "
f"{automation_type.name.lower()}s"
) from err
return platform
@ -131,7 +159,7 @@ async def _async_get_device_automations_from_domain(
except InvalidDeviceAutomationConfig:
return {}
function_name = TYPES[automation_type].get_automations_func
function_name = automation_type.value.get_automations_func
return await asyncio.gather(
*(
@ -143,7 +171,9 @@ async def _async_get_device_automations_from_domain(
async def _async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
hass: HomeAssistant,
automation_type: DeviceAutomationType,
device_ids: Iterable[str] | None,
) -> Mapping[str, list[dict[str, Any]]]:
"""List device automations."""
device_registry = dr.async_get(hass)
@ -188,7 +218,7 @@ async def _async_get_device_automations(
if isinstance(device_results, Exception):
logging.getLogger(__name__).error(
"Unexpected error fetching device %ss",
automation_type,
automation_type.name.lower(),
exc_info=device_results,
)
continue
@ -207,7 +237,9 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
except InvalidDeviceAutomationConfig:
return {}
function_name = TYPES[automation_type].get_capabilities_func
if isinstance(automation_type, str): # until tests pass DeviceAutomationType
automation_type = DeviceAutomationType[automation_type.upper()]
function_name = automation_type.value.get_capabilities_func
if not hasattr(platform, function_name):
# The device automation has no capabilities
@ -256,9 +288,11 @@ def handle_device_errors(func):
async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions."""
device_id = msg["device_id"]
actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
device_id
)
actions = (
await _async_get_device_automations(
hass, DeviceAutomationType.ACTION, [device_id]
)
).get(device_id)
connection.send_result(msg["id"], actions)
@ -274,7 +308,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions."""
device_id = msg["device_id"]
conditions = (
await _async_get_device_automations(hass, "condition", [device_id])
await _async_get_device_automations(
hass, DeviceAutomationType.CONDITION, [device_id]
)
).get(device_id)
connection.send_result(msg["id"], conditions)
@ -290,9 +326,11 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers."""
device_id = msg["device_id"]
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
device_id
)
triggers = (
await _async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, [device_id]
)
).get(device_id)
connection.send_result(msg["id"], triggers)
@ -308,7 +346,7 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
"""Handle request for device action capabilities."""
action = msg["action"]
capabilities = await _async_get_device_automation_capabilities(
hass, "action", action
hass, DeviceAutomationType.ACTION, action
)
connection.send_result(msg["id"], capabilities)
@ -327,7 +365,7 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
"""Handle request for device condition capabilities."""
condition = msg["condition"]
capabilities = await _async_get_device_automation_capabilities(
hass, "condition", condition
hass, DeviceAutomationType.CONDITION, condition
)
connection.send_result(msg["id"], capabilities)
@ -346,6 +384,6 @@ async def websocket_device_automation_get_trigger_capabilities(hass, connection,
"""Handle request for device trigger capabilities."""
trigger = msg["trigger"]
capabilities = await _async_get_device_automation_capabilities(
hass, "trigger", trigger
hass, DeviceAutomationType.TRIGGER, trigger
)
connection.send_result(msg["id"], capabilities)

View file

@ -3,7 +3,11 @@ import voluptuous as vol
from homeassistant.const import CONF_DOMAIN
from . import DEVICE_TRIGGER_BASE_SCHEMA, async_get_device_automation_platform
from . import (
DEVICE_TRIGGER_BASE_SCHEMA,
DeviceAutomationType,
async_get_device_automation_platform,
)
from .exceptions import InvalidDeviceAutomationConfig
# mypy: allow-untyped-defs, no-check-untyped-defs
@ -14,7 +18,7 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
async def async_validate_trigger_config(hass, config):
"""Validate config."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "trigger"
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return platform.TRIGGER_SCHEMA(config)
@ -28,6 +32,6 @@ async def async_validate_trigger_config(hass, config):
async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for trigger."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "trigger"
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
return await platform.async_attach_trigger(hass, config, action, automation_info)

View file

@ -819,7 +819,9 @@ class HomeKit:
valid_device_ids.append(device_id)
for device_id, device_triggers in (
await device_automation.async_get_device_automations(
self.hass, "trigger", valid_device_ids
self.hass,
device_automation.DeviceAutomationType.TRIGGER,
valid_device_ids,
)
).items():
self.add_bridge_triggers_accessory(

View file

@ -512,7 +512,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
async def _async_get_supported_devices(hass):
"""Return all supported devices."""
results = await device_automation.async_get_device_automations(hass, "trigger")
results = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
dev_reg = device_registry.async_get(hass)
unsorted = {
device_id: dev_reg.async_get(device_id).name or device_id

View file

@ -14,6 +14,7 @@ from typing import Any, Callable, cast
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
DeviceAutomationType,
async_get_device_automation_platform,
)
from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP
@ -881,7 +882,7 @@ async def async_device_from_config(
) -> ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition"
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(
cast(
@ -952,7 +953,7 @@ async def async_validate_condition_config(
config = cv.DEVICE_CONDITION_SCHEMA(config)
assert not isinstance(config, Template)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition"
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config) # type: ignore

View file

@ -254,7 +254,7 @@ async def async_validate_action_config(
elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "action"
hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION
)
if hasattr(platform, "async_validate_action_config"):
config = await platform.async_validate_action_config(hass, config) # type: ignore
@ -590,7 +590,9 @@ class _ScriptRun:
"""Perform the device automation specified in the action."""
self._step_log("device automation")
platform = await device_automation.async_get_device_automation_platform(
self._hass, self._action[CONF_DOMAIN], "action"
self._hass,
self._action[CONF_DOMAIN],
device_automation.DeviceAutomationType.ACTION,
)
await platform.async_call_action_from_config(
self._hass, self._action, self._variables, self._context

View file

@ -3,6 +3,7 @@ import pytest
from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry, entity_registry
from homeassistant.setup import async_setup_component
@ -56,7 +57,9 @@ async def test_get_actions(
"entity_id": "NEW_DOMAIN.test_5678",
},
]
actions = await async_get_device_automations(hass, "action", device_entry.id)
actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, device_entry.id
)
assert_lists_same(actions, expected_actions)

View file

@ -5,6 +5,7 @@ import pytest
from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import device_registry, entity_registry
@ -67,7 +68,9 @@ async def test_get_conditions(
"entity_id": f"{DOMAIN}.test_5678",
},
]
conditions = await async_get_device_automations(hass, "condition", device_entry.id)
conditions = await async_get_device_automations(
hass, DeviceAutomationType.CONDITION, device_entry.id
)
assert_lists_same(conditions, expected_conditions)

View file

@ -3,6 +3,7 @@ import pytest
from homeassistant.components import automation
from homeassistant.components.NEW_DOMAIN import DOMAIN
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry
from homeassistant.setup import async_setup_component
@ -60,7 +61,9 @@ async def test_get_triggers(hass, device_reg, entity_reg):
"entity_id": f"{DOMAIN}.test_5678",
},
]
triggers = await async_get_device_automations(hass, "trigger", device_entry.id)
triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device_entry.id
)
assert_lists_same(triggers, expected_triggers)

View file

@ -69,7 +69,9 @@ CLIENT_REDIRECT_URI = "https://example.com/app/callback"
async def async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_id: str
hass: HomeAssistant,
automation_type: device_automation.DeviceAutomationType | str,
device_id: str,
) -> Any:
"""Get a device automation for a single device id."""
automations = await device_automation.async_get_device_automations(

View file

@ -391,6 +391,13 @@ async def test_async_get_device_automations_single_device_trigger(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER, [device_entry.id]
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
# Test deprecated str automation_type works, to be removed in 2022.4
result = await device_automation.async_get_device_automations(
hass, "trigger", [device_entry.id]
)
@ -410,7 +417,9 @@ async def test_async_get_device_automations_all_devices_trigger(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "trigger")
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
@ -427,7 +436,9 @@ async def test_async_get_device_automations_all_devices_condition(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "condition")
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.CONDITION
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
@ -444,7 +455,9 @@ async def test_async_get_device_automations_all_devices_action(
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "action")
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.ACTION
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 3
@ -465,7 +478,9 @@ async def test_async_get_device_automations_all_devices_action_exception_throw(
"homeassistant.components.light.device_trigger.async_get_triggers",
side_effect=KeyError,
):
result = await device_automation.async_get_device_automations(hass, "trigger")
result = await device_automation.async_get_device_automations(
hass, device_automation.DeviceAutomationType.TRIGGER
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 0
assert "KeyError" in caplog.text

View file

@ -16,7 +16,9 @@ async def test_get_actions(hass, push_registration):
]
capabilitites = await device_automation._async_get_device_automation_capabilities(
hass, "action", {"domain": DOMAIN, "device_id": device_id, "type": "notify"}
hass,
device_automation.DeviceAutomationType.ACTION,
{"domain": DOMAIN, "device_id": device_id, "type": "notify"},
)
assert "extra_fields" in capabilitites