From 9d192643eef3af69c6d8dd6c615d83196898bb58 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 22 Nov 2022 11:35:18 +0100 Subject: [PATCH] Move PluggableAction to trigger helpers (#81900) Co-authored-by: Martin Hjelmare Co-authored-by: Joakim Plate Co-authored-by: Shay Levy --- .coveragerc | 1 + .../components/philips_js/__init__.py | 44 +------ homeassistant/components/philips_js/const.py | 2 + .../components/philips_js/device_trigger.py | 40 +++---- .../components/philips_js/helpers.py | 16 +++ .../components/philips_js/media_player.py | 20 +++- homeassistant/components/philips_js/remote.py | 14 ++- homeassistant/components/webostv/__init__.py | 50 +------- .../components/webostv/device_trigger.py | 17 +-- .../components/webostv/media_player.py | 15 ++- .../components/webostv/triggers/turn_on.py | 33 +++++- homeassistant/helpers/trigger.py | 109 ++++++++++++++++++ tests/components/philips_js/conftest.py | 6 +- .../philips_js/test_device_trigger.py | 1 + tests/helpers/test_trigger.py | 83 ++++++++++++- 15 files changed, 304 insertions(+), 147 deletions(-) create mode 100644 homeassistant/components/philips_js/helpers.py diff --git a/.coveragerc b/.coveragerc index 413294573a0..98a9ee9c6ff 100644 --- a/.coveragerc +++ b/.coveragerc @@ -961,6 +961,7 @@ omit = homeassistant/components/pencom/switch.py homeassistant/components/philips_js/__init__.py homeassistant/components/philips_js/diagnostics.py + homeassistant/components/philips_js/helpers.py homeassistant/components/philips_js/light.py homeassistant/components/philips_js/media_player.py homeassistant/components/philips_js/remote.py diff --git a/homeassistant/components/philips_js/__init__.py b/homeassistant/components/philips_js/__init__.py index a31212be3f7..3287d907578 100644 --- a/homeassistant/components/philips_js/__init__.py +++ b/homeassistant/components/philips_js/__init__.py @@ -2,10 +2,9 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Mapping +from collections.abc import Mapping from datetime import timedelta import logging -from typing import Any from haphilipsjs import AutenticationFailure, ConnectionFailure, PhilipsTV from haphilipsjs.typing import SystemType @@ -18,9 +17,8 @@ from homeassistant.const import ( CONF_USERNAME, Platform, ) -from homeassistant.core import Context, HassJob, HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.debounce import Debouncer -from homeassistant.helpers.trigger import TriggerActionType from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import CONF_ALLOW_NOTIFY, CONF_SYSTEM, DOMAIN @@ -78,42 +76,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return unload_ok -class PluggableAction: - """A pluggable action handler.""" - - def __init__(self, update: Callable[[], None]) -> None: - """Initialize.""" - self._update = update - self._actions: dict[ - Any, tuple[HassJob[..., Coroutine[Any, Any, None]], dict[str, Any]] - ] = {} - - def __bool__(self): - """Return if we have something attached.""" - return bool(self._actions) - - @callback - def async_attach(self, action: TriggerActionType, variables: dict[str, Any]): - """Attach a device trigger for turn on.""" - - @callback - def _remove(): - del self._actions[_remove] - self._update() - - job = HassJob(action) - - self._actions[_remove] = (job, variables) - self._update() - - return _remove - - async def async_run(self, hass: HomeAssistant, context: Context | None = None): - """Run all turn on triggers.""" - for job, variables in self._actions.values(): - hass.async_run_hass_job(job, variables, context) - - class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]): """Coordinator to update data.""" @@ -125,8 +87,6 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]): self.options = options self._notify_future: asyncio.Task | None = None - self.turn_on = PluggableAction(self.async_update_listeners) - super().__init__( hass, LOGGER, diff --git a/homeassistant/components/philips_js/const.py b/homeassistant/components/philips_js/const.py index 5d1141a8fb9..7788634ebc0 100644 --- a/homeassistant/components/philips_js/const.py +++ b/homeassistant/components/philips_js/const.py @@ -6,3 +6,5 @@ CONF_ALLOW_NOTIFY = "allow_notify" CONST_APP_ID = "homeassistant.io" CONST_APP_NAME = "Home Assistant" + +TRIGGER_TYPE_TURN_ON = "turn_on" diff --git a/homeassistant/components/philips_js/device_trigger.py b/homeassistant/components/philips_js/device_trigger.py index d7ce9807d64..bdf47674bc8 100644 --- a/homeassistant/components/philips_js/device_trigger.py +++ b/homeassistant/components/philips_js/device_trigger.py @@ -4,17 +4,18 @@ from __future__ import annotations import voluptuous as vol from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA -from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE +from homeassistant.const import CONF_DEVICE_ID, CONF_TYPE from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr -from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo +from homeassistant.helpers.trigger import ( + PluggableAction, + TriggerActionType, + TriggerInfo, +) from homeassistant.helpers.typing import ConfigType -from . import PhilipsTVDataUpdateCoordinator -from .const import DOMAIN - -TRIGGER_TYPE_TURN_ON = "turn_on" +from .const import DOMAIN, TRIGGER_TYPE_TURN_ON +from .helpers import async_get_turn_on_trigger TRIGGER_TYPES = {TRIGGER_TYPE_TURN_ON} TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( @@ -29,14 +30,7 @@ async def async_get_triggers( ) -> list[dict[str, str]]: """List device triggers for device.""" triggers = [] - triggers.append( - { - CONF_PLATFORM: "device", - CONF_DEVICE_ID: device_id, - CONF_DOMAIN: DOMAIN, - CONF_TYPE: TRIGGER_TYPE_TURN_ON, - } - ) + triggers.append(async_get_turn_on_trigger(device_id)) return triggers @@ -49,7 +43,6 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Attach a trigger.""" trigger_data = trigger_info["trigger_data"] - registry: dr.DeviceRegistry = dr.async_get(hass) if (trigger_type := config[CONF_TYPE]) == TRIGGER_TYPE_TURN_ON: variables = { "trigger": { @@ -61,16 +54,9 @@ async def async_attach_trigger( } } - device = registry.async_get(config[CONF_DEVICE_ID]) - if device is None: - raise HomeAssistantError( - f"Device id {config[CONF_DEVICE_ID]} not found in registry" - ) - for config_entry_id in device.config_entries: - coordinator: PhilipsTVDataUpdateCoordinator = hass.data[DOMAIN].get( - config_entry_id - ) - if coordinator: - return coordinator.turn_on.async_attach(action, variables) + turn_on_trigger = async_get_turn_on_trigger(config[CONF_DEVICE_ID]) + return PluggableAction.async_attach_trigger( + hass, turn_on_trigger, action, variables + ) raise HomeAssistantError(f"Unhandled trigger type {trigger_type}") diff --git a/homeassistant/components/philips_js/helpers.py b/homeassistant/components/philips_js/helpers.py new file mode 100644 index 00000000000..010ca7b9a19 --- /dev/null +++ b/homeassistant/components/philips_js/helpers.py @@ -0,0 +1,16 @@ +"""Helpers for philips_js.""" + +from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE + +from .const import DOMAIN, TRIGGER_TYPE_TURN_ON + + +def async_get_turn_on_trigger(device_id: str) -> dict[str, str]: + """Return trigger description for a turn on trigger.""" + + return { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: TRIGGER_TYPE_TURN_ON, + } diff --git a/homeassistant/components/philips_js/media_player.py b/homeassistant/components/philips_js/media_player.py index e1ceddd4bda..04e63008e7b 100644 --- a/homeassistant/components/philips_js/media_player.py +++ b/homeassistant/components/philips_js/media_player.py @@ -19,10 +19,12 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.trigger import PluggableAction from homeassistant.helpers.update_coordinator import CoordinatorEntity from . import LOGGER as _LOGGER, PhilipsTVDataUpdateCoordinator from .const import DOMAIN +from .helpers import async_get_turn_on_trigger SUPPORT_PHILIPS_JS = ( MediaPlayerEntityFeature.TURN_OFF @@ -39,8 +41,6 @@ SUPPORT_PHILIPS_JS = ( | MediaPlayerEntityFeature.STOP ) -CONF_ON_ACTION = "turn_on_action" - def _inverted(data): return {v: k for k, v in data.items()} @@ -95,9 +95,19 @@ class PhilipsTVMediaPlayer( self._media_title: str | None = None self._media_channel: str | None = None + self._turn_on = PluggableAction(self.async_write_ha_state) super().__init__(coordinator) self._update_from_coordinator() + async def async_added_to_hass(self) -> None: + """Handle being added to hass.""" + if (entry := self.registry_entry) and entry.device_id: + self.async_on_remove( + self._turn_on.async_register( + self.hass, async_get_turn_on_trigger(entry.device_id) + ) + ) + async def _async_update_soon(self): """Reschedule update task.""" self.async_write_ha_state() @@ -107,9 +117,7 @@ class PhilipsTVMediaPlayer( def supported_features(self) -> MediaPlayerEntityFeature: """Flag media player features that are supported.""" supports = self._supports - if self.coordinator.turn_on or ( - self._tv.on and self._tv.powerstate is not None - ): + if self._turn_on or (self._tv.on and self._tv.powerstate is not None): supports |= MediaPlayerEntityFeature.TURN_ON return supports @@ -152,7 +160,7 @@ class PhilipsTVMediaPlayer( await self._tv.setPowerState("On") self._state = MediaPlayerState.ON else: - await self.coordinator.turn_on.async_run(self.hass, self._context) + await self._turn_on.async_run(self.hass, self._context) await self._async_update_soon() async def async_turn_off(self) -> None: diff --git a/homeassistant/components/philips_js/remote.py b/homeassistant/components/philips_js/remote.py index 02d5e512a33..3496ec5f576 100644 --- a/homeassistant/components/philips_js/remote.py +++ b/homeassistant/components/philips_js/remote.py @@ -13,10 +13,12 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.trigger import PluggableAction from homeassistant.helpers.update_coordinator import CoordinatorEntity from . import LOGGER, PhilipsTVDataUpdateCoordinator from .const import DOMAIN +from .helpers import async_get_turn_on_trigger async def async_setup_entry( @@ -52,6 +54,16 @@ class PhilipsTVRemote(CoordinatorEntity[PhilipsTVDataUpdateCoordinator], RemoteE name=coordinator.system["name"], sw_version=coordinator.system.get("softwareversion"), ) + self._turn_on = PluggableAction(self.async_write_ha_state) + + async def async_added_to_hass(self) -> None: + """Handle being added to hass.""" + if (entry := self.registry_entry) and entry.device_id: + self.async_on_remove( + self._turn_on.async_register( + self.hass, async_get_turn_on_trigger(entry.device_id) + ) + ) @property def is_on(self): @@ -65,7 +77,7 @@ class PhilipsTVRemote(CoordinatorEntity[PhilipsTVDataUpdateCoordinator], RemoteE if self._tv.on and self._tv.powerstate: await self._tv.setPowerState("On") else: - await self.coordinator.turn_on.async_run(self.hass, self._context) + await self._turn_on.async_run(self.hass, self._context) self.async_write_ha_state() async def async_turn_off(self, **kwargs: Any) -> None: diff --git a/homeassistant/components/webostv/__init__.py b/homeassistant/components/webostv/__init__.py index 8b023990590..cd5485d4fd2 100644 --- a/homeassistant/components/webostv/__init__.py +++ b/homeassistant/components/webostv/__init__.py @@ -1,10 +1,8 @@ """Support for LG webOS Smart TV.""" from __future__ import annotations -from collections.abc import Callable, Coroutine from contextlib import suppress import logging -from typing import Any from aiowebostv import WebOsClient, WebOsTvPairError import voluptuous as vol @@ -19,17 +17,9 @@ from homeassistant.const import ( CONF_NAME, EVENT_HOMEASSISTANT_STOP, ) -from homeassistant.core import ( - Context, - Event, - HassJob, - HomeAssistant, - ServiceCall, - callback, -) +from homeassistant.core import Event, HomeAssistant, ServiceCall from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers.dispatcher import async_dispatcher_send -from homeassistant.helpers.trigger import TriggerActionType from homeassistant.helpers.typing import ConfigType from .const import ( @@ -165,43 +155,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return unload_ok -class PluggableAction: - """A pluggable action handler.""" - - def __init__(self) -> None: - """Initialize.""" - self._actions: dict[ - Callable[[], None], - tuple[HassJob[..., Coroutine[Any, Any, None]], dict[str, Any]], - ] = {} - - def __bool__(self) -> bool: - """Return if we have something attached.""" - return bool(self._actions) - - @callback - def async_attach( - self, action: TriggerActionType, variables: dict[str, Any] - ) -> Callable[[], None]: - """Attach a device trigger for turn on.""" - - @callback - def _remove() -> None: - del self._actions[_remove] - - job = HassJob(action) - - self._actions[_remove] = (job, variables) - - return _remove - - @callback - def async_run(self, hass: HomeAssistant, context: Context | None = None) -> None: - """Run all turn on triggers.""" - for job, variables in self._actions.values(): - hass.async_run_hass_job(job, variables, context) - - class WebOsClientWrapper: """Wrapper for a WebOS TV client with Home Assistant specific functions.""" @@ -209,7 +162,6 @@ class WebOsClientWrapper: """Set up the client.""" self.host = host self.client_key = client_key - self.turn_on = PluggableAction() self.client: WebOsClient | None = None async def connect(self) -> None: diff --git a/homeassistant/components/webostv/device_trigger.py b/homeassistant/components/webostv/device_trigger.py index 590cbc19de8..14854383ec8 100644 --- a/homeassistant/components/webostv/device_trigger.py +++ b/homeassistant/components/webostv/device_trigger.py @@ -7,7 +7,7 @@ from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEM from homeassistant.components.device_automation.exceptions import ( InvalidDeviceAutomationConfig, ) -from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE +from homeassistant.const import CONF_DEVICE_ID, CONF_PLATFORM, CONF_TYPE from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo @@ -19,7 +19,10 @@ from .helpers import ( async_get_client_wrapper_by_device_entry, async_get_device_entry_by_device_id, ) -from .triggers.turn_on import PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE +from .triggers.turn_on import ( + PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE, + async_get_turn_on_trigger, +) TRIGGER_TYPES = {TURN_ON_PLATFORM_TYPE} TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( @@ -51,15 +54,7 @@ async def async_get_triggers( _hass: HomeAssistant, device_id: str ) -> list[dict[str, str]]: """List device triggers for device.""" - triggers = [] - base_trigger = { - CONF_PLATFORM: "device", - CONF_DEVICE_ID: device_id, - CONF_DOMAIN: DOMAIN, - } - - triggers.append({**base_trigger, CONF_TYPE: TURN_ON_PLATFORM_TYPE}) - + triggers = [async_get_turn_on_trigger(device_id)] return triggers diff --git a/homeassistant/components/webostv/media_player.py b/homeassistant/components/webostv/media_player.py index 07d49f703f7..dcbec24c665 100644 --- a/homeassistant/components/webostv/media_player.py +++ b/homeassistant/components/webostv/media_player.py @@ -32,6 +32,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers.trigger import PluggableAction from . import WebOsClientWrapper from .const import ( @@ -43,6 +44,7 @@ from .const import ( LIVE_TV_APP_ID, WEBOSTV_EXCEPTIONS, ) +from .triggers.turn_on import async_get_turn_on_trigger _LOGGER = logging.getLogger(__name__) @@ -133,7 +135,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): # Assume that the TV is not paused self._paused = False - + self._turn_on = PluggableAction(self.async_write_ha_state) self._current_source = None self._source_list: dict = {} @@ -144,6 +146,13 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): """Connect and subscribe to dispatcher signals and state updates.""" await super().async_added_to_hass() + if (entry := self.registry_entry) and entry.device_id: + self.async_on_remove( + self._turn_on.async_register( + self.hass, async_get_turn_on_trigger(entry.device_id) + ) + ) + self.async_on_remove( async_dispatcher_connect(self.hass, DOMAIN, self.async_signal_handler) ) @@ -318,7 +327,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): @property def supported_features(self) -> MediaPlayerEntityFeature: """Flag media player features that are supported.""" - if self._wrapper.turn_on: + if self._turn_on: return self._supported_features | MediaPlayerEntityFeature.TURN_ON return self._supported_features @@ -330,7 +339,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): async def async_turn_on(self) -> None: """Turn on media player.""" - self._wrapper.turn_on.async_run(self.hass, self._context) + await self._turn_on.async_run(self.hass, self._context) @cmd async def async_volume_up(self) -> None: diff --git a/homeassistant/components/webostv/triggers/turn_on.py b/homeassistant/components/webostv/triggers/turn_on.py index 806b0b4b964..403219f1372 100644 --- a/homeassistant/components/webostv/triggers/turn_on.py +++ b/homeassistant/components/webostv/triggers/turn_on.py @@ -3,15 +3,25 @@ from __future__ import annotations import voluptuous as vol -from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM +from homeassistant.const import ( + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + CONF_DEVICE_ID, + CONF_DOMAIN, + CONF_PLATFORM, + CONF_TYPE, +) from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo +from homeassistant.helpers.trigger import ( + PluggableAction, + TriggerActionType, + TriggerInfo, +) from homeassistant.helpers.typing import ConfigType from ..const import DOMAIN from ..helpers import ( - async_get_client_wrapper_by_device_entry, async_get_device_entry_by_device_id, async_get_device_id_from_entity_id, ) @@ -33,6 +43,17 @@ TRIGGER_SCHEMA = vol.All( ) +def async_get_turn_on_trigger(device_id: str) -> dict[str, str]: + """Return data for a turn on trigger.""" + + return { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: PLATFORM_TYPE, + } + + async def async_attach_trigger( hass: HomeAssistant, config: ConfigType, @@ -69,10 +90,12 @@ async def async_attach_trigger( "description": f"webostv turn on trigger for {device_name}", } - client_wrapper = async_get_client_wrapper_by_device_entry(hass, device) + turn_on_trigger = async_get_turn_on_trigger(device_id) unsubs.append( - client_wrapper.turn_on.async_attach(action, {"trigger": variables}) + PluggableAction.async_attach_trigger( + hass, turn_on_trigger, action, {"trigger": variables} + ) ) @callback diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 4cb724a6435..1054521ee51 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -2,7 +2,9 @@ from __future__ import annotations import asyncio +from collections import defaultdict from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field import functools import logging from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast @@ -19,6 +21,7 @@ from homeassistant.const import ( from homeassistant.core import ( CALLBACK_TYPE, Context, + HassJob, HomeAssistant, callback, is_callback, @@ -38,6 +41,8 @@ _PLATFORM_ALIASES = { "homeassistant": ("event", "numeric_state", "state", "time_pattern", "time"), } +DATA_PLUGGABLE_ACTIONS = "pluggable_actions" + class TriggerActionType(Protocol): """Protocol type for trigger action callback.""" @@ -68,6 +73,110 @@ class TriggerInfo(TypedDict): trigger_data: TriggerData +@dataclass +class PluggableActionsEntry: + """Holder to keep track of all plugs and actions for a given trigger.""" + + plugs: set[PluggableAction] = field(default_factory=set) + actions: dict[object, tuple[HassJob, dict[str, Any]]] = field(default_factory=dict) + + +class PluggableAction: + """A pluggable action handler.""" + + _entry: PluggableActionsEntry | None = None + + def __init__(self, update: CALLBACK_TYPE | None = None) -> None: + """Initialize a pluggable action. + + :param update: callback triggered whenever triggers are attached or removed. + """ + self._update = update + + def __bool__(self) -> bool: + """Return if we have something attached.""" + return bool(self._entry and self._entry.actions) + + @callback + def async_run_update(self) -> None: + """Run update function if one exists.""" + if self._update: + self._update() + + @staticmethod + @callback + def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]: + """Return the pluggable actions registry.""" + if data := hass.data.get(DATA_PLUGGABLE_ACTIONS): + return data # type: ignore[no-any-return] + data = defaultdict(PluggableActionsEntry) + hass.data[DATA_PLUGGABLE_ACTIONS] = data + return data + + @staticmethod + @callback + def async_attach_trigger( + hass: HomeAssistant, + trigger: dict[str, str], + action: TriggerActionType, + variables: dict[str, Any], + ) -> CALLBACK_TYPE: + """Attach an action to a trigger entry. Existing or future plugs registered will be attached.""" + reg = PluggableAction.async_get_registry(hass) + key = tuple(sorted(trigger.items())) + entry = reg[key] + + def _update() -> None: + for plug in entry.plugs: + plug.async_run_update() + + @callback + def _remove() -> None: + """Remove this action attachment, and disconnect all plugs.""" + del entry.actions[_remove] + _update() + if not entry.actions and not entry.plugs: + del reg[key] + + job = HassJob(action) + entry.actions[_remove] = (job, variables) + _update() + + return _remove + + @callback + def async_register( + self, hass: HomeAssistant, trigger: dict[str, str] + ) -> CALLBACK_TYPE: + """Register plug in the global plugs dictionary.""" + + reg = PluggableAction.async_get_registry(hass) + key = tuple(sorted(trigger.items())) + self._entry = reg[key] + self._entry.plugs.add(self) + + @callback + def _remove() -> None: + """Remove plug from registration, and clean up entry if there are no actions or plugs registered.""" + assert self._entry + self._entry.plugs.remove(self) + if not self._entry.actions and not self._entry.plugs: + del reg[key] + self._entry = None + + return _remove + + async def async_run( + self, hass: HomeAssistant, context: Context | None = None + ) -> None: + """Run all actions.""" + assert self._entry + for job, variables in self._entry.actions.values(): + task = hass.async_run_hass_job(job, variables, context) + if task: + await task + + async def _async_get_trigger_platform( hass: HomeAssistant, config: ConfigType ) -> DeviceAutomationTriggerProtocol: diff --git a/tests/components/philips_js/conftest.py b/tests/components/philips_js/conftest.py index e0069cf9b75..dfda844c7a2 100644 --- a/tests/components/philips_js/conftest.py +++ b/tests/components/philips_js/conftest.py @@ -31,6 +31,10 @@ def mock_tv(): tv.notify_change_supported = False tv.pairing_type = None tv.powerstate = None + tv.source_id = None + tv.ambilight_current_configuration = None + tv.ambilight_styles = {} + tv.ambilight_cached = {} with patch( "homeassistant.components.philips_js.config_flow.PhilipsTV", return_value=tv @@ -42,7 +46,7 @@ def mock_tv(): async def mock_config_entry(hass): """Get standard player.""" config_entry = MockConfigEntry( - domain=DOMAIN, data=MOCK_CONFIG, title=MOCK_NAME, unique_id="ABCDEFGHIJKLF" + domain=DOMAIN, data=MOCK_CONFIG, title=MOCK_NAME, unique_id=MOCK_SERIAL_NO ) config_entry.add_to_hass(hass) return config_entry diff --git a/tests/components/philips_js/test_device_trigger.py b/tests/components/philips_js/test_device_trigger.py index dd06ee25d49..2c5b21b1e34 100644 --- a/tests/components/philips_js/test_device_trigger.py +++ b/tests/components/philips_js/test_device_trigger.py @@ -34,6 +34,7 @@ async def test_get_triggers(hass, mock_device): triggers = await async_get_device_automations( hass, DeviceAutomationType.TRIGGER, mock_device.id ) + triggers = [trigger for trigger in triggers if trigger["domain"] == DOMAIN] assert_lists_same(triggers, expected_triggers) diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 9cd3b0956ce..4718e3130d5 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -1,11 +1,13 @@ """The tests for the trigger helper.""" -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest import voluptuous as vol -from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.core import Context, HomeAssistant, ServiceCall, callback from homeassistant.helpers.trigger import ( + DATA_PLUGGABLE_ACTIONS, + PluggableAction, _async_get_trigger_platform, async_initialize_triggers, async_validate_trigger_config, @@ -197,3 +199,80 @@ async def test_async_initialize_triggers( log_cb.reset_mock() unsub() + + +async def test_pluggable_action(hass: HomeAssistant, calls: list[ServiceCall]): + """Test normal behavior of pluggable actions.""" + update_1 = MagicMock() + update_2 = MagicMock() + action_1 = AsyncMock() + action_2 = AsyncMock() + trigger_1 = {"domain": "test", "device": "1"} + trigger_2 = {"domain": "test", "device": "2"} + variables_1 = {"source": "test 1"} + variables_2 = {"source": "test 2"} + context_1 = Context() + context_2 = Context() + + plug_1 = PluggableAction(update_1) + plug_2 = PluggableAction(update_2) + + # Verify plug is inactive without triggers + remove_plug_1 = plug_1.async_register(hass, trigger_1) + assert not plug_1 + assert not plug_2 + + # Verify plug remain inactive with non matching trigger + remove_attach_2 = PluggableAction.async_attach_trigger( + hass, trigger_2, action_2, variables_2 + ) + assert not plug_1 + assert not plug_2 + update_1.assert_not_called() + update_2.assert_not_called() + + # Verify plug is active, and update when matching trigger attaches + remove_attach_1 = PluggableAction.async_attach_trigger( + hass, trigger_1, action_1, variables_1 + ) + assert plug_1 + assert not plug_2 + update_1.assert_called() + update_1.reset_mock() + update_2.assert_not_called() + + # Verify a non registered plug is inactive + remove_plug_1() + assert not plug_1 + assert not plug_2 + + # Verify a plug registered to existing trigger is true + remove_plug_1 = plug_1.async_register(hass, trigger_1) + assert plug_1 + assert not plug_2 + + remove_plug_2 = plug_2.async_register(hass, trigger_2) + assert plug_1 + assert plug_2 + + # Verify no actions should have been triggered so far + action_1.assert_not_called() + action_2.assert_not_called() + + # Verify action is triggered with correct data + await plug_1.async_run(hass, context_1) + await plug_2.async_run(hass, context_2) + action_1.assert_called_with(variables_1, context_1) + action_2.assert_called_with(variables_2, context_2) + + # Verify plug goes inactive if trigger is removed + remove_attach_1() + assert not plug_1 + + # Verify registry is cleaned when no plugs nor triggers are attached + assert hass.data[DATA_PLUGGABLE_ACTIONS] + remove_plug_1() + remove_plug_2() + remove_attach_2() + assert not hass.data[DATA_PLUGGABLE_ACTIONS] + assert not plug_2