diff --git a/homeassistant/components/apple_tv/remote.py b/homeassistant/components/apple_tv/remote.py index f3be6977891..bab3421c58d 100644 --- a/homeassistant/components/apple_tv/remote.py +++ b/homeassistant/components/apple_tv/remote.py @@ -21,6 +21,15 @@ from .const import DOMAIN _LOGGER = logging.getLogger(__name__) PARALLEL_UPDATES = 0 +COMMAND_TO_ATTRIBUTE = { + "wakeup": ("power", "turn_on"), + "suspend": ("power", "turn_off"), + "turn_on": ("power", "turn_on"), + "turn_off": ("power", "turn_off"), + "volume_up": ("audio", "volume_up"), + "volume_down": ("audio", "volume_down"), + "home_hold": ("remote_control", "home"), +} async def async_setup_entry( @@ -61,7 +70,13 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity): for _ in range(num_repeats): for single_command in command: - attr_value = getattr(self.atv.remote_control, single_command, None) + attr_value = None + if attributes := COMMAND_TO_ATTRIBUTE.get(single_command): + attr_value = self.atv + for attr_name in attributes: + attr_value = getattr(attr_value, attr_name, None) + if not attr_value: + attr_value = getattr(self.atv.remote_control, single_command, None) if not attr_value: raise ValueError("Command not found. Exiting sequence") diff --git a/tests/components/apple_tv/test_remote.py b/tests/components/apple_tv/test_remote.py new file mode 100644 index 00000000000..db2a4964f6c --- /dev/null +++ b/tests/components/apple_tv/test_remote.py @@ -0,0 +1,28 @@ +"""Test apple_tv remote.""" +from unittest.mock import AsyncMock + +import pytest + +from homeassistant.components.apple_tv.remote import AppleTVRemote +from homeassistant.components.remote import ATTR_DELAY_SECS, ATTR_NUM_REPEATS + + +@pytest.mark.parametrize( + ("command", "method"), + [ + ("up", "remote_control.up"), + ("wakeup", "power.turn_on"), + ("volume_up", "audio.volume_up"), + ("home_hold", "remote_control.home"), + ], + ids=["up", "wakeup", "volume_up", "home_hold"], +) +async def test_send_command(command: str, method: str) -> None: + """Test "send_command" method.""" + remote = AppleTVRemote("test", "test", None) + remote.atv = AsyncMock() + await remote.async_send_command( + [command], **{ATTR_NUM_REPEATS: 1, ATTR_DELAY_SECS: 0} + ) + assert len(remote.atv.method_calls) == 1 + assert str(remote.atv.method_calls[0]) == f"call.{method}()"