Add turn_on trigger to Samsung TV (#89018)

* Add turn_on trigger to Samsung TV

* Add tests

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove assert

* Cleanup mock_send_magic_packet

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
epenet 2023-03-15 12:43:53 +01:00 committed by GitHub
parent cd23caff58
commit 6270776fbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 675 additions and 2 deletions

View file

@ -0,0 +1,80 @@
"""Provides device automations for control of Samsung TV."""
from __future__ import annotations
import voluptuous as vol
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig,
)
from homeassistant.const import CONF_DEVICE_ID, CONF_PLATFORM, CONF_TYPE
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from . import trigger
from .const import DOMAIN
from .helpers import (
async_get_client_by_device_entry,
async_get_device_entry_by_device_id,
)
from .triggers.turn_on import (
PLATFORM_TYPE as TURN_ON_PLATFORM_TYPE,
async_get_turn_on_trigger,
)
TRIGGER_TYPES = {TURN_ON_PLATFORM_TYPE}
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_TYPE): vol.In(TRIGGER_TYPES),
}
)
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
config = TRIGGER_SCHEMA(config)
if config[CONF_TYPE] == TURN_ON_PLATFORM_TYPE:
device_id = config[CONF_DEVICE_ID]
try:
device = async_get_device_entry_by_device_id(hass, device_id)
if DOMAIN in hass.data:
async_get_client_by_device_entry(hass, device)
except ValueError as err:
raise InvalidDeviceAutomationConfig(err) from err
return config
async def async_get_triggers(
_hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device triggers for device."""
triggers = [async_get_turn_on_trigger(device_id)]
return triggers
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
if (trigger_type := config[CONF_TYPE]) == TURN_ON_PLATFORM_TYPE:
trigger_config = {
CONF_PLATFORM: trigger_type,
CONF_DEVICE_ID: config[CONF_DEVICE_ID],
}
trigger_config = await trigger.async_validate_trigger_config(
hass, trigger_config
)
return await trigger.async_attach_trigger(
hass, trigger_config, action, trigger_info
)
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")

View file

@ -0,0 +1,61 @@
"""Helper functions for Samsung TV."""
from __future__ import annotations
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceEntry
from .bridge import SamsungTVBridge
from .const import DOMAIN
@callback
def async_get_device_entry_by_device_id(
hass: HomeAssistant, device_id: str
) -> DeviceEntry:
"""Get Device Entry from Device Registry by device ID.
Raises ValueError if device ID is invalid.
"""
device_reg = dr.async_get(hass)
if (device := device_reg.async_get(device_id)) is None:
raise ValueError(f"Device {device_id} is not a valid {DOMAIN} device.")
return device
@callback
def async_get_device_id_from_entity_id(hass: HomeAssistant, entity_id: str) -> str:
"""Get device ID from an entity ID.
Raises ValueError if entity or device ID is invalid.
"""
ent_reg = er.async_get(hass)
entity_entry = ent_reg.async_get(entity_id)
if (
entity_entry is None
or entity_entry.device_id is None
or entity_entry.platform != DOMAIN
):
raise ValueError(f"Entity {entity_id} is not a valid {DOMAIN} entity.")
return entity_entry.device_id
@callback
def async_get_client_by_device_entry(
hass: HomeAssistant, device: DeviceEntry
) -> SamsungTVBridge:
"""Get SamsungTVBridge from Device Registry by device entry.
Raises ValueError if client is not found.
"""
domain_data: dict[str, SamsungTVBridge] = hass.data[DOMAIN]
for config_entry_id in device.config_entries:
if bridge := domain_data.get(config_entry_id):
return bridge
raise ValueError(
f"Device {device.id} is not from an existing {DOMAIN} config entry"
)

View file

@ -40,6 +40,7 @@ from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.trigger import PluggableAction
from homeassistant.util import dt as dt_util
from .bridge import SamsungTVBridge, SamsungTVWSBridge
@ -51,6 +52,7 @@ from .const import (
DOMAIN,
LOGGER,
)
from .triggers.turn_on import async_get_turn_on_trigger
SOURCES = {"TV": "KEY_TV", "HDMI": "KEY_HDMI"}
@ -112,6 +114,7 @@ class SamsungTVDevice(MediaPlayerEntity):
self._ssdp_rendering_control_location: str | None = config_entry.data.get(
CONF_SSDP_RENDERING_CONTROL_LOCATION
)
self._turn_on = PluggableAction(self.async_write_ha_state)
self._on_script = on_script
# Assume that the TV is in Play mode
self._playing: bool = True
@ -125,8 +128,8 @@ class SamsungTVDevice(MediaPlayerEntity):
self._app_list_event: asyncio.Event = asyncio.Event()
self._attr_supported_features = SUPPORT_SAMSUNGTV
if self._on_script or self._mac:
# Add turn-on if on_script or mac is available
if self._turn_on or self._on_script or self._mac:
# Add turn-on if turn_on trigger or on_script YAML or mac is available
self._attr_supported_features |= MediaPlayerEntityFeature.TURN_ON
if self._ssdp_rendering_control_location:
self._attr_supported_features |= MediaPlayerEntityFeature.VOLUME_SET
@ -359,11 +362,23 @@ class SamsungTVDevice(MediaPlayerEntity):
return False
return (
self.state == MediaPlayerState.ON
or bool(self._turn_on)
or self._on_script is not None
or self._mac is not None
or self._power_off_in_progress()
)
async def async_added_to_hass(self) -> None:
"""Connect and subscribe to dispatcher signals and state updates."""
await super().async_added_to_hass()
if (entry := self.registry_entry) and entry.device_id:
self.async_on_remove(
self._turn_on.async_register(
self.hass, async_get_turn_on_trigger(entry.device_id)
)
)
async def async_turn_off(self) -> None:
"""Turn off media player."""
self._end_of_power_off = dt_util.utcnow() + SCAN_INTERVAL_PLUS_OFF_TIME
@ -448,6 +463,9 @@ class SamsungTVDevice(MediaPlayerEntity):
async def async_turn_on(self) -> None:
"""Turn the media player on."""
if self._turn_on:
await self._turn_on.async_run(self.hass, self._context)
# on_script is deprecated - replaced by turn_on trigger
if self._on_script:
await self._on_script.async_run(context=self._context)
elif self._mac:

View file

@ -39,5 +39,10 @@
"unknown": "[%key:common::config_flow::error::unknown%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
}
},
"device_automation": {
"trigger_type": {
"samsungtv.turn_on": "Device is requested to turn on"
}
}
}

View file

@ -0,0 +1,46 @@
"""Samsung TV trigger dispatcher."""
from __future__ import annotations
from typing import cast
from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.trigger import (
TriggerActionType,
TriggerInfo,
TriggerProtocol,
)
from homeassistant.helpers.typing import ConfigType
from .triggers import turn_on
TRIGGERS = {
"turn_on": turn_on,
}
def _get_trigger_platform(config: ConfigType) -> TriggerProtocol:
"""Return trigger platform."""
platform_split = config[CONF_PLATFORM].split(".", maxsplit=1)
if len(platform_split) < 2 or platform_split[1] not in TRIGGERS:
raise ValueError(f"Unknown Samsung TV trigger platform {config[CONF_PLATFORM]}")
return cast(TriggerProtocol, TRIGGERS[platform_split[1]])
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
platform = _get_trigger_platform(config)
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach trigger of specified platform."""
platform = _get_trigger_platform(config)
return await platform.async_attach_trigger(hass, config, action, trigger_info)

View file

@ -0,0 +1 @@
"""Samsung TV triggers."""

View file

@ -0,0 +1,108 @@
"""Samsung TV device turn on trigger."""
from __future__ import annotations
import voluptuous as vol
from homeassistant.const import (
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_PLATFORM,
CONF_TYPE,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.trigger import (
PluggableAction,
TriggerActionType,
TriggerInfo,
)
from homeassistant.helpers.typing import ConfigType
from ..const import DOMAIN
from ..helpers import (
async_get_device_entry_by_device_id,
async_get_device_id_from_entity_id,
)
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
TRIGGER_TYPE_TURN_ON = "turn_on"
TRIGGER_SCHEMA = vol.All(
cv.TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): PLATFORM_TYPE,
vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
},
),
cv.has_at_least_one_key(ATTR_ENTITY_ID, ATTR_DEVICE_ID),
)
def async_get_turn_on_trigger(device_id: str) -> dict[str, str]:
"""Return data for a turn on trigger."""
return {
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
CONF_TYPE: PLATFORM_TYPE,
}
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
*,
platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE | None:
"""Attach a trigger."""
device_ids = set()
if ATTR_DEVICE_ID in config:
device_ids.update(config.get(ATTR_DEVICE_ID, []))
if ATTR_ENTITY_ID in config:
device_ids.update(
{
async_get_device_id_from_entity_id(hass, entity_id)
for entity_id in config.get(ATTR_ENTITY_ID, [])
}
)
trigger_data = trigger_info["trigger_data"]
unsubs = []
for device_id in device_ids:
device = async_get_device_entry_by_device_id(hass, device_id)
device_name = device.name_by_user or device.name
variables = {
**trigger_data,
CONF_PLATFORM: platform_type,
ATTR_DEVICE_ID: device_id,
"description": f"Samsung turn on trigger for {device_name}",
}
turn_on_trigger = async_get_turn_on_trigger(device_id)
unsubs.append(
PluggableAction.async_attach_trigger(
hass, turn_on_trigger, action, {"trigger": variables}
)
)
@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()
return async_remove

View file

@ -20,10 +20,13 @@ from samsungtvws.exceptions import ResponseError
from samsungtvws.remote import ChannelEmitCommand
from homeassistant.components.samsungtv.const import WEBSOCKET_SSL_PORT
from homeassistant.core import HomeAssistant, ServiceCall
import homeassistant.util.dt as dt_util
from .const import SAMPLE_DEVICE_INFO_UE48JU6400, SAMPLE_DEVICE_INFO_WIFI
from tests.common import async_mock_service
@pytest.fixture
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@ -307,3 +310,9 @@ def mac_address_fixture() -> Mock:
"""Patch getmac.get_mac_address."""
with patch("getmac.get_mac_address", return_value=None) as mac:
yield mac
@pytest.fixture
def calls(hass: HomeAssistant) -> list[ServiceCall]:
"""Track calls to a mock service."""
return async_mock_service(hass, "test", "automation")

View file

@ -0,0 +1,149 @@
"""The tests for Samsung TV device triggers."""
from unittest.mock import patch
import pytest
from homeassistant.components import automation
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig,
)
from homeassistant.components.samsungtv import DOMAIN, device_trigger
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.device_registry import async_get as get_dev_reg
from homeassistant.setup import async_setup_component
from . import setup_samsungtv_entry
from .test_media_player import ENTITY_ID, MOCK_ENTRYDATA_ENCRYPTED_WS
from tests.common import MockConfigEntry, async_get_device_automations
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_get_triggers(hass: HomeAssistant) -> None:
"""Test we get the expected triggers."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
device_reg = get_dev_reg(hass)
device = device_reg.async_get_device(identifiers={(DOMAIN, "any")})
turn_on_trigger = {
"platform": "device",
"domain": DOMAIN,
"type": "samsungtv.turn_on",
"device_id": device.id,
"metadata": {},
}
triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device.id
)
assert turn_on_trigger in triggers
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_if_fires_on_turn_on_request(
hass: HomeAssistant, calls: list[ServiceCall]
) -> None:
"""Test for turn_on and turn_off triggers firing."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
device_reg = get_dev_reg(hass)
device = device_reg.async_get_device(identifiers={(DOMAIN, "any")})
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "device",
"domain": DOMAIN,
"device_id": device.id,
"type": "samsungtv.turn_on",
},
"action": {
"service": "test.automation",
"data_template": {
"some": "{{ trigger.device_id }}",
"id": "{{ trigger.id }}",
},
},
},
{
"trigger": {
"platform": "samsungtv.turn_on",
"entity_id": ENTITY_ID,
},
"action": {
"service": "test.automation",
"data_template": {
"some": ENTITY_ID,
"id": "{{ trigger.id }}",
},
},
},
],
},
)
with patch("homeassistant.components.samsungtv.media_player.send_magic_packet"):
await hass.services.async_call(
"media_player",
"turn_on",
{"entity_id": ENTITY_ID},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].data["some"] == device.id
assert calls[0].data["id"] == 0
assert calls[1].data["some"] == ENTITY_ID
assert calls[1].data["id"] == 0
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_failure_scenarios(hass: HomeAssistant) -> None:
"""Test failure scenarios."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
# Test wrong trigger platform type
with pytest.raises(HomeAssistantError):
await device_trigger.async_attach_trigger(
hass, {"type": "wrong.type", "device_id": "invalid_device_id"}, None, {}
)
# Test invalid device id
with pytest.raises(InvalidDeviceAutomationConfig):
await device_trigger.async_validate_trigger_config(
hass,
{
"platform": "device",
"domain": DOMAIN,
"type": "samsungtv.turn_on",
"device_id": "invalid_device_id",
},
)
entry = MockConfigEntry(domain="fake", state=ConfigEntryState.LOADED, data={})
entry.add_to_hass(hass)
device_reg = get_dev_reg(hass)
device = device_reg.async_get_or_create(
config_entry_id=entry.entry_id, identifiers={("fake", "fake")}
)
config = {
"platform": "device",
"domain": DOMAIN,
"device_id": device.id,
"type": "samsungtv.turn_on",
}
# Test that device id from non samsungtv domain raises exception
with pytest.raises(InvalidDeviceAutomationConfig):
await device_trigger.async_validate_trigger_config(hass, config)

View file

@ -0,0 +1,196 @@
"""The tests for WebOS TV automation triggers."""
from unittest.mock import patch
import pytest
from homeassistant.components import automation
from homeassistant.components.samsungtv import DOMAIN
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import device_registry as dr
from homeassistant.setup import async_setup_component
from . import setup_samsungtv_entry
from .test_media_player import ENTITY_ID, MOCK_ENTRYDATA_ENCRYPTED_WS
from tests.common import MockEntity, MockEntityPlatform
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_turn_on_trigger_device_id(
hass: HomeAssistant, calls: list[ServiceCall], device_registry: dr.DeviceRegistry
) -> None:
"""Test for turn_on triggers by device_id firing."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
device = device_registry.async_get_device(identifiers={(DOMAIN, "any")})
assert device, repr(device_registry.devices)
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "samsungtv.turn_on",
"device_id": device.id,
},
"action": {
"service": "test.automation",
"data_template": {
"some": device.id,
"id": "{{ trigger.id }}",
},
},
},
],
},
)
with patch("homeassistant.components.samsungtv.media_player.send_magic_packet"):
await hass.services.async_call(
"media_player",
"turn_on",
{"entity_id": ENTITY_ID},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["some"] == device.id
assert calls[0].data["id"] == 0
with patch("homeassistant.config.load_yaml", return_value={}):
await hass.services.async_call(automation.DOMAIN, SERVICE_RELOAD, blocking=True)
calls.clear()
with patch("homeassistant.components.samsungtv.media_player.send_magic_packet"):
await hass.services.async_call(
"media_player",
"turn_on",
{"entity_id": ENTITY_ID},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 0
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_turn_on_trigger_entity_id(
hass: HomeAssistant, calls: list[ServiceCall]
) -> None:
"""Test for turn_on triggers by entity_id firing."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "samsungtv.turn_on",
"entity_id": ENTITY_ID,
},
"action": {
"service": "test.automation",
"data_template": {
"some": ENTITY_ID,
"id": "{{ trigger.id }}",
},
},
},
],
},
)
with patch("homeassistant.components.samsungtv.media_player.send_magic_packet"):
await hass.services.async_call(
"media_player",
"turn_on",
{"entity_id": ENTITY_ID},
blocking=True,
)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["some"] == ENTITY_ID
assert calls[0].data["id"] == 0
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_wrong_trigger_platform_type(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test wrong trigger platform type."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "samsungtv.wrong_type",
"entity_id": ENTITY_ID,
},
"action": {
"service": "test.automation",
"data_template": {
"some": ENTITY_ID,
"id": "{{ trigger.id }}",
},
},
},
],
},
)
assert (
"ValueError: Unknown Samsung TV trigger platform samsungtv.wrong_type"
in caplog.text
)
@pytest.mark.usefixtures("remoteencws", "rest_api")
async def test_trigger_invalid_entity_id(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test turn on trigger using invalid entity_id."""
await setup_samsungtv_entry(hass, MOCK_ENTRYDATA_ENCRYPTED_WS)
platform = MockEntityPlatform(hass)
invalid_entity = f"{DOMAIN}.invalid"
await platform.async_add_entities([MockEntity(name=invalid_entity)])
await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "samsungtv.turn_on",
"entity_id": invalid_entity,
},
"action": {
"service": "test.automation",
"data_template": {
"some": ENTITY_ID,
"id": "{{ trigger.id }}",
},
},
},
],
},
)
assert (
f"ValueError: Entity {invalid_entity} is not a valid samsungtv entity"
in caplog.text
)