From 3ebb2fc3a9bf78dbe6ffb5542685e7e5ee10b52f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Feb 2023 14:33:27 -0500 Subject: [PATCH] Fix handling of HomeKit sources with unsafe characters (#88280) fixes #87049 --- .../components/homekit/type_media_players.py | 4 +- .../components/homekit/type_remotes.py | 29 ++++---- .../homekit/test_type_media_players.py | 67 +++++++++++++++++++ 3 files changed, 85 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/homekit/type_media_players.py b/homeassistant/components/homekit/type_media_players.py index 55519fdf6f7..eae7ed2742a 100644 --- a/homeassistant/components/homekit/type_media_players.py +++ b/homeassistant/components/homekit/type_media_players.py @@ -305,8 +305,8 @@ class TelevisionMediaPlayer(RemoteInputSelectAccessory): def set_input_source(self, value): """Send input set value if call came from HomeKit.""" _LOGGER.debug("%s: Set current input to %s", self.entity_id, value) - source = self.sources[value] - params = {ATTR_ENTITY_ID: self.entity_id, ATTR_INPUT_SOURCE: source} + source_name = self._mapped_sources[self.sources[value]] + params = {ATTR_ENTITY_ID: self.entity_id, ATTR_INPUT_SOURCE: source_name} self.async_call_service(DOMAIN, SERVICE_SELECT_SOURCE, params) def set_remote_key(self, value): diff --git a/homeassistant/components/homekit/type_remotes.py b/homeassistant/components/homekit/type_remotes.py index 1dfcb0f91a3..69441b5ebe1 100644 --- a/homeassistant/components/homekit/type_remotes.py +++ b/homeassistant/components/homekit/type_remotes.py @@ -91,6 +91,8 @@ class RemoteInputSelectAccessory(HomeAccessory, ABC): state = self.hass.states.get(self.entity_id) features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) + self._mapped_sources_list = [] + self._mapped_sources = {} self.source_key = source_key self.source_list_key = source_list_key self.sources = [] @@ -103,9 +105,7 @@ class RemoteInputSelectAccessory(HomeAccessory, ABC): self.entity_id, MAXIMUM_SOURCES, ) - self.sources = [ - cleanup_name_for_homekit(source) for source in sources[:MAXIMUM_SOURCES] - ] + self.sources = sources[:MAXIMUM_SOURCES] if self.sources: self.support_select_source = True @@ -143,6 +143,15 @@ class RemoteInputSelectAccessory(HomeAccessory, ABC): serv_input.configure_char(CHAR_CURRENT_VISIBILITY_STATE, value=False) _LOGGER.debug("%s: Added source %s", self.entity_id, source) + def _get_mapped_sources(self, state: State) -> dict[str, str]: + """Return a dict of sources mapped to their homekit safe name.""" + source_list = state.attributes.get(self.source_list_key, []) + if self._mapped_sources_list != source_list: + self._mapped_sources = { + cleanup_name_for_homekit(source): source for source in source_list + } + return self._mapped_sources + def _get_ordered_source_list_from_state(self, state: State) -> list[str]: """Return ordered source list while preserving order with duplicates removed. @@ -150,13 +159,7 @@ class RemoteInputSelectAccessory(HomeAccessory, ABC): which will make the source list conflict as HomeKit requires unique source names. """ - seen = set() - sources: list[str] = [] - for source in state.attributes.get(self.source_list_key, []): - if source not in seen: - sources.append(source) - seen.add(source) - return sources + return list(self._get_mapped_sources(state)) @abstractmethod def set_on_off(self, value): @@ -185,8 +188,8 @@ class RemoteInputSelectAccessory(HomeAccessory, ABC): return possible_sources = self._get_ordered_source_list_from_state(new_state) - if source in possible_sources: - index = possible_sources.index(source) + if source_name in possible_sources: + index = possible_sources.index(source_name) if index >= MAXIMUM_SOURCES: _LOGGER.debug( "%s: Source %s and above are not supported", @@ -235,7 +238,7 @@ class ActivityRemote(RemoteInputSelectAccessory): def set_input_source(self, value): """Send input set value if call came from HomeKit.""" _LOGGER.debug("%s: Set current input to %s", self.entity_id, value) - source = self.sources[value] + source = self._mapped_sources[self.sources[value]] params = {ATTR_ENTITY_ID: self.entity_id, ATTR_ACTIVITY: source} self.async_call_service(REMOTE_DOMAIN, SERVICE_TURN_ON, params) diff --git a/tests/components/homekit/test_type_media_players.py b/tests/components/homekit/test_type_media_players.py index b9a2f829801..f68adc24077 100644 --- a/tests/components/homekit/test_type_media_players.py +++ b/tests/components/homekit/test_type_media_players.py @@ -562,3 +562,70 @@ async def test_media_player_television_duplicate_sources( ) await hass.async_block_till_done() assert acc.char_input_source.value == 0 + + +async def test_media_player_television_unsafe_chars( + hass: HomeAssistant, hk_driver, events, caplog: pytest.LogCaptureFixture +) -> None: + """Test if television accessory with unsafe characters.""" + entity_id = "media_player.television" + sources = ["MUSIC", "HDMI 3/ARC", "SCREEN MIRRORING", "HDMI 2/MHL", "HDMI", "MUSIC"] + hass.states.async_set( + entity_id, + None, + { + ATTR_DEVICE_CLASS: MediaPlayerDeviceClass.TV, + ATTR_SUPPORTED_FEATURES: 3469, + ATTR_MEDIA_VOLUME_MUTED: False, + ATTR_INPUT_SOURCE: "HDMI 2/MHL", + ATTR_INPUT_SOURCE_LIST: sources, + }, + ) + await hass.async_block_till_done() + acc = TelevisionMediaPlayer(hass, hk_driver, "MediaPlayer", entity_id, 2, None) + await acc.run() + await hass.async_block_till_done() + + assert acc.aid == 2 + assert acc.category == 31 # Television + + assert acc.char_active.value == 0 + assert acc.char_remote_key.value == 0 + assert acc.char_input_source.value == 3 + assert acc.char_mute.value is False + + hass.states.async_set( + entity_id, + None, + { + ATTR_DEVICE_CLASS: MediaPlayerDeviceClass.TV, + ATTR_SUPPORTED_FEATURES: 3469, + ATTR_MEDIA_VOLUME_MUTED: False, + ATTR_INPUT_SOURCE: "HDMI 3/ARC", + ATTR_INPUT_SOURCE_LIST: sources, + }, + ) + await hass.async_block_till_done() + assert acc.char_input_source.value == 1 + + call_select_source = async_mock_service(hass, DOMAIN, "select_source") + + acc.char_input_source.client_update_value(3) + await hass.async_block_till_done() + assert call_select_source + assert call_select_source[0].data[ATTR_ENTITY_ID] == entity_id + assert call_select_source[0].data[ATTR_INPUT_SOURCE] == "HDMI 2/MHL" + assert len(events) == 1 + assert events[-1].data[ATTR_VALUE] is None + + assert acc.char_input_source.value == 3 + + acc.char_input_source.client_update_value(4) + await hass.async_block_till_done() + assert call_select_source + assert call_select_source[1].data[ATTR_ENTITY_ID] == entity_id + assert call_select_source[1].data[ATTR_INPUT_SOURCE] == "HDMI" + assert len(events) == 2 + assert events[-1].data[ATTR_VALUE] is None + + assert acc.char_input_source.value == 4