Move PluggableAction to trigger helpers (#81900)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Joakim Plate <elupus@ecce.se>
Co-authored-by: Shay Levy <levyshay1@gmail.com>
This commit is contained in:
epenet 2022-11-22 11:35:18 +01:00 committed by GitHub
parent b566d55998
commit 9d192643ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 304 additions and 147 deletions

View file

@ -961,6 +961,7 @@ omit =
homeassistant/components/pencom/switch.py homeassistant/components/pencom/switch.py
homeassistant/components/philips_js/__init__.py homeassistant/components/philips_js/__init__.py
homeassistant/components/philips_js/diagnostics.py homeassistant/components/philips_js/diagnostics.py
homeassistant/components/philips_js/helpers.py
homeassistant/components/philips_js/light.py homeassistant/components/philips_js/light.py
homeassistant/components/philips_js/media_player.py homeassistant/components/philips_js/media_player.py
homeassistant/components/philips_js/remote.py homeassistant/components/philips_js/remote.py

View file

@ -2,10 +2,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Mapping from collections.abc import Mapping
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any
from haphilipsjs import AutenticationFailure, ConnectionFailure, PhilipsTV from haphilipsjs import AutenticationFailure, ConnectionFailure, PhilipsTV
from haphilipsjs.typing import SystemType from haphilipsjs.typing import SystemType
@ -18,9 +17,8 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
Platform, Platform,
) )
from homeassistant.core import Context, HassJob, HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.trigger import TriggerActionType
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import CONF_ALLOW_NOTIFY, CONF_SYSTEM, DOMAIN from .const import CONF_ALLOW_NOTIFY, CONF_SYSTEM, DOMAIN
@ -78,42 +76,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return unload_ok return unload_ok
class PluggableAction:
"""A pluggable action handler."""
def __init__(self, update: Callable[[], None]) -> None:
"""Initialize."""
self._update = update
self._actions: dict[
Any, tuple[HassJob[..., Coroutine[Any, Any, None]], dict[str, Any]]
] = {}
def __bool__(self):
"""Return if we have something attached."""
return bool(self._actions)
@callback
def async_attach(self, action: TriggerActionType, variables: dict[str, Any]):
"""Attach a device trigger for turn on."""
@callback
def _remove():
del self._actions[_remove]
self._update()
job = HassJob(action)
self._actions[_remove] = (job, variables)
self._update()
return _remove
async def async_run(self, hass: HomeAssistant, context: Context | None = None):
"""Run all turn on triggers."""
for job, variables in self._actions.values():
hass.async_run_hass_job(job, variables, context)
class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]): class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
"""Coordinator to update data.""" """Coordinator to update data."""
@ -125,8 +87,6 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
self.options = options self.options = options
self._notify_future: asyncio.Task | None = None self._notify_future: asyncio.Task | None = None
self.turn_on = PluggableAction(self.async_update_listeners)
super().__init__( super().__init__(
hass, hass,
LOGGER, LOGGER,

View file

@ -6,3 +6,5 @@ CONF_ALLOW_NOTIFY = "allow_notify"
CONST_APP_ID = "homeassistant.io" CONST_APP_ID = "homeassistant.io"
CONST_APP_NAME = "Home Assistant" CONST_APP_NAME = "Home Assistant"
TRIGGER_TYPE_TURN_ON = "turn_on"

View file

@ -4,17 +4,18 @@ from __future__ import annotations
import voluptuous as vol import voluptuous as vol
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE from homeassistant.const import CONF_DEVICE_ID, CONF_TYPE
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr from homeassistant.helpers.trigger import (
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo PluggableAction,
TriggerActionType,
TriggerInfo,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import PhilipsTVDataUpdateCoordinator from .const import DOMAIN, TRIGGER_TYPE_TURN_ON
from .const import DOMAIN from .helpers import async_get_turn_on_trigger
TRIGGER_TYPE_TURN_ON = "turn_on"
TRIGGER_TYPES = {TRIGGER_TYPE_TURN_ON} TRIGGER_TYPES = {TRIGGER_TYPE_TURN_ON}
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend(
@ -29,14 +30,7 @@ async def async_get_triggers(
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""List device triggers for device.""" """List device triggers for device."""
triggers = [] triggers = []
triggers.append( triggers.append(async_get_turn_on_trigger(device_id))
{
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
CONF_TYPE: TRIGGER_TYPE_TURN_ON,
}
)
return triggers return triggers
@ -49,7 +43,6 @@ async def async_attach_trigger(
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
trigger_data = trigger_info["trigger_data"] trigger_data = trigger_info["trigger_data"]
registry: dr.DeviceRegistry = dr.async_get(hass)
if (trigger_type := config[CONF_TYPE]) == TRIGGER_TYPE_TURN_ON: if (trigger_type := config[CONF_TYPE]) == TRIGGER_TYPE_TURN_ON:
variables = { variables = {
"trigger": { "trigger": {
@ -61,16 +54,9 @@ async def async_attach_trigger(
} }
} }
device = registry.async_get(config[CONF_DEVICE_ID]) turn_on_trigger = async_get_turn_on_trigger(config[CONF_DEVICE_ID])
if device is None: return PluggableAction.async_attach_trigger(
raise HomeAssistantError( hass, turn_on_trigger, action, variables
f"Device id {config[CONF_DEVICE_ID]} not found in registry" )
)
for config_entry_id in device.config_entries:
coordinator: PhilipsTVDataUpdateCoordinator = hass.data[DOMAIN].get(
config_entry_id
)
if coordinator:
return coordinator.turn_on.async_attach(action, variables)
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}") raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")

View file

@ -0,0 +1,16 @@
"""Helpers for philips_js."""
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE
from .const import DOMAIN, TRIGGER_TYPE_TURN_ON
def async_get_turn_on_trigger(device_id: str) -> dict[str, str]:
"""Return trigger description for a turn on trigger."""
return {
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
CONF_TYPE: TRIGGER_TYPE_TURN_ON,
}

View file

@ -19,10 +19,12 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.trigger import PluggableAction
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import LOGGER as _LOGGER, PhilipsTVDataUpdateCoordinator from . import LOGGER as _LOGGER, PhilipsTVDataUpdateCoordinator
from .const import DOMAIN from .const import DOMAIN
from .helpers import async_get_turn_on_trigger
SUPPORT_PHILIPS_JS = ( SUPPORT_PHILIPS_JS = (
MediaPlayerEntityFeature.TURN_OFF MediaPlayerEntityFeature.TURN_OFF
@ -39,8 +41,6 @@ SUPPORT_PHILIPS_JS = (
| MediaPlayerEntityFeature.STOP | MediaPlayerEntityFeature.STOP
) )
CONF_ON_ACTION = "turn_on_action"
def _inverted(data): def _inverted(data):
return {v: k for k, v in data.items()} return {v: k for k, v in data.items()}
@ -95,9 +95,19 @@ class PhilipsTVMediaPlayer(
self._media_title: str | None = None self._media_title: str | None = None
self._media_channel: str | None = None self._media_channel: str | None = None
self._turn_on = PluggableAction(self.async_write_ha_state)
super().__init__(coordinator) super().__init__(coordinator)
self._update_from_coordinator() self._update_from_coordinator()
async def async_added_to_hass(self) -> None:
"""Handle being added to hass."""
if (entry := self.registry_entry) and entry.device_id:
self.async_on_remove(
self._turn_on.async_register(
self.hass, async_get_turn_on_trigger(entry.device_id)
)
)
async def _async_update_soon(self): async def _async_update_soon(self):
"""Reschedule update task.""" """Reschedule update task."""
self.async_write_ha_state() self.async_write_ha_state()
@ -107,9 +117,7 @@ class PhilipsTVMediaPlayer(
def supported_features(self) -> MediaPlayerEntityFeature: def supported_features(self) -> MediaPlayerEntityFeature:
"""Flag media player features that are supported.""" """Flag media player features that are supported."""
supports = self._supports supports = self._supports
if self.coordinator.turn_on or ( if self._turn_on or (self._tv.on and self._tv.powerstate is not None):
self._tv.on and self._tv.powerstate is not None
):
supports |= MediaPlayerEntityFeature.TURN_ON supports |= MediaPlayerEntityFeature.TURN_ON
return supports return supports
@ -152,7 +160,7 @@ class PhilipsTVMediaPlayer(
await self._tv.setPowerState("On") await self._tv.setPowerState("On")
self._state = MediaPlayerState.ON self._state = MediaPlayerState.ON
else: else:
await self.coordinator.turn_on.async_run(self.hass, self._context) await self._turn_on.async_run(self.hass, self._context)
await self._async_update_soon() await self._async_update_soon()
async def async_turn_off(self) -> None: async def async_turn_off(self) -> None:

View file

@ -13,10 +13,12 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.trigger import PluggableAction
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import LOGGER, PhilipsTVDataUpdateCoordinator from . import LOGGER, PhilipsTVDataUpdateCoordinator
from .const import DOMAIN from .const import DOMAIN
from .helpers import async_get_turn_on_trigger
async def async_setup_entry( async def async_setup_entry(
@ -52,6 +54,16 @@ class PhilipsTVRemote(CoordinatorEntity[PhilipsTVDataUpdateCoordinator], RemoteE
name=coordinator.system["name"], name=coordinator.system["name"],
sw_version=coordinator.system.get("softwareversion"), sw_version=coordinator.system.get("softwareversion"),
) )
self._turn_on = PluggableAction(self.async_write_ha_state)
async def async_added_to_hass(self) -> None:
"""Handle being added to hass."""
if (entry := self.registry_entry) and entry.device_id:
self.async_on_remove(
self._turn_on.async_register(
self.hass, async_get_turn_on_trigger(entry.device_id)
)
)
@property @property
def is_on(self): def is_on(self):
@ -65,7 +77,7 @@ class PhilipsTVRemote(CoordinatorEntity[PhilipsTVDataUpdateCoordinator], RemoteE
if self._tv.on and self._tv.powerstate: if self._tv.on and self._tv.powerstate:
await self._tv.setPowerState("On") await self._tv.setPowerState("On")
else: else:
await self.coordinator.turn_on.async_run(self.hass, self._context) await self._turn_on.async_run(self.hass, self._context)
self.async_write_ha_state() self.async_write_ha_state()
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:

View file

@ -1,10 +1,8 @@
"""Support for LG webOS Smart TV.""" """Support for LG webOS Smart TV."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine
from contextlib import suppress from contextlib import suppress
import logging import logging
from typing import Any
from aiowebostv import WebOsClient, WebOsTvPairError from aiowebostv import WebOsClient, WebOsTvPairError
import voluptuous as vol import voluptuous as vol
@ -19,17 +17,9 @@ from homeassistant.const import (
CONF_NAME, CONF_NAME,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.core import ( from homeassistant.core import Event, HomeAssistant, ServiceCall
Context,
Event,
HassJob,
HomeAssistant,
ServiceCall,
callback,
)
from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.trigger import TriggerActionType
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
@ -165,43 +155,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return unload_ok return unload_ok
class PluggableAction:
"""A pluggable action handler."""
def __init__(self) -> None:
"""Initialize."""
self._actions: dict[
Callable[[], None],
tuple[HassJob[..., Coroutine[Any, Any, None]], dict[str, Any]],
] = {}
def __bool__(self) -> bool:
"""Return if we have something attached."""
return bool(self._actions)
@callback
def async_attach(
self, action: TriggerActionType, variables: dict[str, Any]
) -> Callable[[], None]:
"""Attach a device trigger for turn on."""
@callback
def _remove() -> None:
del self._actions[_remove]
job = HassJob(action)
self._actions[_remove] = (job, variables)
return _remove
@callback
def async_run(self, hass: HomeAssistant, context: Context | None = None) -> None:
"""Run all turn on triggers."""
for job, variables in self._actions.values():
hass.async_run_hass_job(job, variables, context)
class WebOsClientWrapper: class WebOsClientWrapper:
"""Wrapper for a WebOS TV client with Home Assistant specific functions.""" """Wrapper for a WebOS TV client with Home Assistant specific functions."""
@ -209,7 +162,6 @@ class WebOsClientWrapper:
"""Set up the client.""" """Set up the client."""
self.host = host self.host = host
self.client_key = client_key self.client_key = client_key
self.turn_on = PluggableAction()
self.client: WebOsClient | None = None self.client: WebOsClient | None = None
async def connect(self) -> None: async def connect(self) -> None:

View file

@ -7,7 +7,7 @@ from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEM
from homeassistant.components.device_automation.exceptions import ( from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig, InvalidDeviceAutomationConfig,
) )
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE from homeassistant.const import CONF_DEVICE_ID, CONF_PLATFORM, CONF_TYPE
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
@ -19,7 +19,10 @@ from .helpers import (
async_get_client_wrapper_by_device_entry, async_get_client_wrapper_by_device_entry,
async_get_device_entry_by_device_id, async_get_device_entry_by_device_id,
) )
from .triggers.turn_on import PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE from .triggers.turn_on import (
PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE,
async_get_turn_on_trigger,
)
TRIGGER_TYPES = {TURN_ON_PLATFORM_TYPE} TRIGGER_TYPES = {TURN_ON_PLATFORM_TYPE}
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend(
@ -51,15 +54,7 @@ async def async_get_triggers(
_hass: HomeAssistant, device_id: str _hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""List device triggers for device.""" """List device triggers for device."""
triggers = [] triggers = [async_get_turn_on_trigger(device_id)]
base_trigger = {
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
}
triggers.append({**base_trigger, CONF_TYPE: TURN_ON_PLATFORM_TYPE})
return triggers return triggers

View file

@ -32,6 +32,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.trigger import PluggableAction
from . import WebOsClientWrapper from . import WebOsClientWrapper
from .const import ( from .const import (
@ -43,6 +44,7 @@ from .const import (
LIVE_TV_APP_ID, LIVE_TV_APP_ID,
WEBOSTV_EXCEPTIONS, WEBOSTV_EXCEPTIONS,
) )
from .triggers.turn_on import async_get_turn_on_trigger
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -133,7 +135,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
# Assume that the TV is not paused # Assume that the TV is not paused
self._paused = False self._paused = False
self._turn_on = PluggableAction(self.async_write_ha_state)
self._current_source = None self._current_source = None
self._source_list: dict = {} self._source_list: dict = {}
@ -144,6 +146,13 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
"""Connect and subscribe to dispatcher signals and state updates.""" """Connect and subscribe to dispatcher signals and state updates."""
await super().async_added_to_hass() await super().async_added_to_hass()
if (entry := self.registry_entry) and entry.device_id:
self.async_on_remove(
self._turn_on.async_register(
self.hass, async_get_turn_on_trigger(entry.device_id)
)
)
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect(self.hass, DOMAIN, self.async_signal_handler) async_dispatcher_connect(self.hass, DOMAIN, self.async_signal_handler)
) )
@ -318,7 +327,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
@property @property
def supported_features(self) -> MediaPlayerEntityFeature: def supported_features(self) -> MediaPlayerEntityFeature:
"""Flag media player features that are supported.""" """Flag media player features that are supported."""
if self._wrapper.turn_on: if self._turn_on:
return self._supported_features | MediaPlayerEntityFeature.TURN_ON return self._supported_features | MediaPlayerEntityFeature.TURN_ON
return self._supported_features return self._supported_features
@ -330,7 +339,7 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
async def async_turn_on(self) -> None: async def async_turn_on(self) -> None:
"""Turn on media player.""" """Turn on media player."""
self._wrapper.turn_on.async_run(self.hass, self._context) await self._turn_on.async_run(self.hass, self._context)
@cmd @cmd
async def async_volume_up(self) -> None: async def async_volume_up(self) -> None:

View file

@ -3,15 +3,25 @@ from __future__ import annotations
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM from homeassistant.const import (
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_PLATFORM,
CONF_TYPE,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import (
PluggableAction,
TriggerActionType,
TriggerInfo,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from ..const import DOMAIN from ..const import DOMAIN
from ..helpers import ( from ..helpers import (
async_get_client_wrapper_by_device_entry,
async_get_device_entry_by_device_id, async_get_device_entry_by_device_id,
async_get_device_id_from_entity_id, async_get_device_id_from_entity_id,
) )
@ -33,6 +43,17 @@ TRIGGER_SCHEMA = vol.All(
) )
def async_get_turn_on_trigger(device_id: str) -> dict[str, str]:
"""Return data for a turn on trigger."""
return {
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
CONF_TYPE: PLATFORM_TYPE,
}
async def async_attach_trigger( async def async_attach_trigger(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
@ -69,10 +90,12 @@ async def async_attach_trigger(
"description": f"webostv turn on trigger for {device_name}", "description": f"webostv turn on trigger for {device_name}",
} }
client_wrapper = async_get_client_wrapper_by_device_entry(hass, device) turn_on_trigger = async_get_turn_on_trigger(device_id)
unsubs.append( unsubs.append(
client_wrapper.turn_on.async_attach(action, {"trigger": variables}) PluggableAction.async_attach_trigger(
hass, turn_on_trigger, action, {"trigger": variables}
)
) )
@callback @callback

View file

@ -2,7 +2,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
import functools import functools
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
@ -19,6 +21,7 @@ from homeassistant.const import (
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Context, Context,
HassJob,
HomeAssistant, HomeAssistant,
callback, callback,
is_callback, is_callback,
@ -38,6 +41,8 @@ _PLATFORM_ALIASES = {
"homeassistant": ("event", "numeric_state", "state", "time_pattern", "time"), "homeassistant": ("event", "numeric_state", "state", "time_pattern", "time"),
} }
DATA_PLUGGABLE_ACTIONS = "pluggable_actions"
class TriggerActionType(Protocol): class TriggerActionType(Protocol):
"""Protocol type for trigger action callback.""" """Protocol type for trigger action callback."""
@ -68,6 +73,110 @@ class TriggerInfo(TypedDict):
trigger_data: TriggerData trigger_data: TriggerData
@dataclass
class PluggableActionsEntry:
"""Holder to keep track of all plugs and actions for a given trigger."""
plugs: set[PluggableAction] = field(default_factory=set)
actions: dict[object, tuple[HassJob, dict[str, Any]]] = field(default_factory=dict)
class PluggableAction:
"""A pluggable action handler."""
_entry: PluggableActionsEntry | None = None
def __init__(self, update: CALLBACK_TYPE | None = None) -> None:
"""Initialize a pluggable action.
:param update: callback triggered whenever triggers are attached or removed.
"""
self._update = update
def __bool__(self) -> bool:
"""Return if we have something attached."""
return bool(self._entry and self._entry.actions)
@callback
def async_run_update(self) -> None:
"""Run update function if one exists."""
if self._update:
self._update()
@staticmethod
@callback
def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
"""Return the pluggable actions registry."""
if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
return data # type: ignore[no-any-return]
data = defaultdict(PluggableActionsEntry)
hass.data[DATA_PLUGGABLE_ACTIONS] = data
return data
@staticmethod
@callback
def async_attach_trigger(
hass: HomeAssistant,
trigger: dict[str, str],
action: TriggerActionType,
variables: dict[str, Any],
) -> CALLBACK_TYPE:
"""Attach an action to a trigger entry. Existing or future plugs registered will be attached."""
reg = PluggableAction.async_get_registry(hass)
key = tuple(sorted(trigger.items()))
entry = reg[key]
def _update() -> None:
for plug in entry.plugs:
plug.async_run_update()
@callback
def _remove() -> None:
"""Remove this action attachment, and disconnect all plugs."""
del entry.actions[_remove]
_update()
if not entry.actions and not entry.plugs:
del reg[key]
job = HassJob(action)
entry.actions[_remove] = (job, variables)
_update()
return _remove
@callback
def async_register(
self, hass: HomeAssistant, trigger: dict[str, str]
) -> CALLBACK_TYPE:
"""Register plug in the global plugs dictionary."""
reg = PluggableAction.async_get_registry(hass)
key = tuple(sorted(trigger.items()))
self._entry = reg[key]
self._entry.plugs.add(self)
@callback
def _remove() -> None:
"""Remove plug from registration, and clean up entry if there are no actions or plugs registered."""
assert self._entry
self._entry.plugs.remove(self)
if not self._entry.actions and not self._entry.plugs:
del reg[key]
self._entry = None
return _remove
async def async_run(
self, hass: HomeAssistant, context: Context | None = None
) -> None:
"""Run all actions."""
assert self._entry
for job, variables in self._entry.actions.values():
task = hass.async_run_hass_job(job, variables, context)
if task:
await task
async def _async_get_trigger_platform( async def _async_get_trigger_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> DeviceAutomationTriggerProtocol: ) -> DeviceAutomationTriggerProtocol:

View file

@ -31,6 +31,10 @@ def mock_tv():
tv.notify_change_supported = False tv.notify_change_supported = False
tv.pairing_type = None tv.pairing_type = None
tv.powerstate = None tv.powerstate = None
tv.source_id = None
tv.ambilight_current_configuration = None
tv.ambilight_styles = {}
tv.ambilight_cached = {}
with patch( with patch(
"homeassistant.components.philips_js.config_flow.PhilipsTV", return_value=tv "homeassistant.components.philips_js.config_flow.PhilipsTV", return_value=tv
@ -42,7 +46,7 @@ def mock_tv():
async def mock_config_entry(hass): async def mock_config_entry(hass):
"""Get standard player.""" """Get standard player."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
domain=DOMAIN, data=MOCK_CONFIG, title=MOCK_NAME, unique_id="ABCDEFGHIJKLF" domain=DOMAIN, data=MOCK_CONFIG, title=MOCK_NAME, unique_id=MOCK_SERIAL_NO
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
return config_entry return config_entry

View file

@ -34,6 +34,7 @@ async def test_get_triggers(hass, mock_device):
triggers = await async_get_device_automations( triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, mock_device.id hass, DeviceAutomationType.TRIGGER, mock_device.id
) )
triggers = [trigger for trigger in triggers if trigger["domain"] == DOMAIN]
assert_lists_same(triggers, expected_triggers) assert_lists_same(triggers, expected_triggers)

View file

@ -1,11 +1,13 @@
"""The tests for the trigger helper.""" """The tests for the trigger helper."""
from unittest.mock import ANY, MagicMock, call, patch from unittest.mock import ANY, AsyncMock, MagicMock, call, patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import Context, HomeAssistant, ServiceCall, callback
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS,
PluggableAction,
_async_get_trigger_platform, _async_get_trigger_platform,
async_initialize_triggers, async_initialize_triggers,
async_validate_trigger_config, async_validate_trigger_config,
@ -197,3 +199,80 @@ async def test_async_initialize_triggers(
log_cb.reset_mock() log_cb.reset_mock()
unsub() unsub()
async def test_pluggable_action(hass: HomeAssistant, calls: list[ServiceCall]):
"""Test normal behavior of pluggable actions."""
update_1 = MagicMock()
update_2 = MagicMock()
action_1 = AsyncMock()
action_2 = AsyncMock()
trigger_1 = {"domain": "test", "device": "1"}
trigger_2 = {"domain": "test", "device": "2"}
variables_1 = {"source": "test 1"}
variables_2 = {"source": "test 2"}
context_1 = Context()
context_2 = Context()
plug_1 = PluggableAction(update_1)
plug_2 = PluggableAction(update_2)
# Verify plug is inactive without triggers
remove_plug_1 = plug_1.async_register(hass, trigger_1)
assert not plug_1
assert not plug_2
# Verify plug remain inactive with non matching trigger
remove_attach_2 = PluggableAction.async_attach_trigger(
hass, trigger_2, action_2, variables_2
)
assert not plug_1
assert not plug_2
update_1.assert_not_called()
update_2.assert_not_called()
# Verify plug is active, and update when matching trigger attaches
remove_attach_1 = PluggableAction.async_attach_trigger(
hass, trigger_1, action_1, variables_1
)
assert plug_1
assert not plug_2
update_1.assert_called()
update_1.reset_mock()
update_2.assert_not_called()
# Verify a non registered plug is inactive
remove_plug_1()
assert not plug_1
assert not plug_2
# Verify a plug registered to existing trigger is true
remove_plug_1 = plug_1.async_register(hass, trigger_1)
assert plug_1
assert not plug_2
remove_plug_2 = plug_2.async_register(hass, trigger_2)
assert plug_1
assert plug_2
# Verify no actions should have been triggered so far
action_1.assert_not_called()
action_2.assert_not_called()
# Verify action is triggered with correct data
await plug_1.async_run(hass, context_1)
await plug_2.async_run(hass, context_2)
action_1.assert_called_with(variables_1, context_1)
action_2.assert_called_with(variables_2, context_2)
# Verify plug goes inactive if trigger is removed
remove_attach_1()
assert not plug_1
# Verify registry is cleaned when no plugs nor triggers are attached
assert hass.data[DATA_PLUGGABLE_ACTIONS]
remove_plug_1()
remove_plug_2()
remove_attach_2()
assert not hass.data[DATA_PLUGGABLE_ACTIONS]
assert not plug_2