Fix service annotations (#31402)

* Fix service annotations

* Filter area_id from service data

* Fix services not accepting entities

* Typo
This commit is contained in:
Paulus Schoutsen 2020-02-02 15:36:39 -08:00 committed by GitHub
parent 81dbdc6b9c
commit 7687ac8b91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 29 deletions

View file

@ -143,11 +143,15 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_SELECT_NEXT, {}, lambda entity, call: entity.async_offset_index(1) SERVICE_SELECT_NEXT,
{},
callback(lambda entity, call: entity.async_offset_index(1)),
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_SELECT_PREVIOUS, {}, lambda entity, call: entity.async_offset_index(-1) SERVICE_SELECT_PREVIOUS,
{},
callback(lambda entity, call: entity.async_offset_index(-1)),
) )
component.async_register_entity_service( component.async_register_entity_service(
@ -248,7 +252,8 @@ class InputSelect(RestoreEntity):
"""Return unique id for the entity.""" """Return unique id for the entity."""
return self._config[CONF_ID] return self._config[CONF_ID]
async def async_select_option(self, option): @callback
def async_select_option(self, option):
"""Select new option.""" """Select new option."""
if option not in self._options: if option not in self._options:
_LOGGER.warning( _LOGGER.warning(
@ -260,14 +265,16 @@ class InputSelect(RestoreEntity):
self._current_option = option self._current_option = option
self.async_write_ha_state() self.async_write_ha_state()
async def async_offset_index(self, offset): @callback
def async_offset_index(self, offset):
"""Offset current index.""" """Offset current index."""
current_index = self._options.index(self._current_option) current_index = self._options.index(self._current_option)
new_index = (current_index + offset) % len(self._options) new_index = (current_index + offset) % len(self._options)
self._current_option = self._options[new_index] self._current_option = self._options[new_index]
self.async_write_ha_state() self.async_write_ha_state()
async def async_set_options(self, options): @callback
def async_set_options(self, options):
"""Set options.""" """Set options."""
self._current_option = options[0] self._current_option = options[0]
self._config[CONF_OPTIONS] = options self._config[CONF_OPTIONS] = options

View file

@ -173,6 +173,23 @@ SCHEMA_WEBSOCKET_GET_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.exten
) )
def _rename_keys(**keys):
"""Create validator that renames keys.
Necessary because the service schema names do not match the command parameters.
Async friendly.
"""
def rename(value):
for to_key, from_key in keys.items():
if from_key in value:
value[to_key] = value.pop(from_key)
return value
return rename
async def async_setup(hass, config): async def async_setup(hass, config):
"""Track states and offer events for media_players.""" """Track states and offer events for media_players."""
component = hass.data[DOMAIN] = EntityComponent( component = hass.data[DOMAIN] = EntityComponent(
@ -238,30 +255,39 @@ async def async_setup(hass, config):
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_VOLUME_SET, SERVICE_VOLUME_SET,
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}, vol.All(
lambda entity, call: entity.async_set_volume_level( cv.make_entity_service_schema(
volume=call.data[ATTR_MEDIA_VOLUME_LEVEL] {vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}
),
_rename_keys(volume=ATTR_MEDIA_VOLUME_LEVEL),
), ),
"async_set_volume_level",
[SUPPORT_VOLUME_SET], [SUPPORT_VOLUME_SET],
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_VOLUME_MUTE, SERVICE_VOLUME_MUTE,
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}, vol.All(
lambda entity, call: entity.async_mute_volume( cv.make_entity_service_schema(
mute=call.data[ATTR_MEDIA_VOLUME_MUTED] {vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}
),
_rename_keys(mute=ATTR_MEDIA_VOLUME_MUTED),
), ),
"async_mute_volume",
[SUPPORT_VOLUME_MUTE], [SUPPORT_VOLUME_MUTE],
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_MEDIA_SEEK, SERVICE_MEDIA_SEEK,
{ vol.All(
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All( cv.make_entity_service_schema(
vol.Coerce(float), vol.Range(min=0) {
) vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
}, vol.Coerce(float), vol.Range(min=0)
lambda entity, call: entity.async_media_seek( )
position=call.data[ATTR_MEDIA_SEEK_POSITION] }
),
_rename_keys(position=ATTR_MEDIA_SEEK_POSITION),
), ),
"async_media_seek",
[SUPPORT_SEEK], [SUPPORT_SEEK],
) )
component.async_register_entity_service( component.async_register_entity_service(
@ -278,12 +304,15 @@ async def async_setup(hass, config):
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_PLAY_MEDIA, SERVICE_PLAY_MEDIA,
MEDIA_PLAYER_PLAY_MEDIA_SCHEMA, vol.All(
lambda entity, call: entity.async_play_media( cv.make_entity_service_schema(MEDIA_PLAYER_PLAY_MEDIA_SCHEMA),
media_type=call.data[ATTR_MEDIA_CONTENT_TYPE], _rename_keys(
media_id=call.data[ATTR_MEDIA_CONTENT_ID], media_type=ATTR_MEDIA_CONTENT_TYPE,
enqueue=call.data.get(ATTR_MEDIA_ENQUEUE), media_id=ATTR_MEDIA_CONTENT_ID,
enqueue=ATTR_MEDIA_ENQUEUE,
),
), ),
"async_play_media",
[SUPPORT_PLAY_MEDIA], [SUPPORT_PLAY_MEDIA],
) )
component.async_register_entity_service( component.async_register_entity_service(

View file

@ -724,6 +724,8 @@ PLATFORM_SCHEMA = vol.Schema(
PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
ENTITY_SERVICE_FIELDS = (ATTR_ENTITY_ID, ATTR_AREA_ID)
def make_entity_service_schema( def make_entity_service_schema(
schema: dict, *, extra: int = vol.PREVENT_EXTRA schema: dict, *, extra: int = vol.PREVENT_EXTRA
@ -738,7 +740,7 @@ def make_entity_service_schema(
}, },
extra=extra, extra=extra,
), ),
has_at_least_one_key(ATTR_ENTITY_ID, ATTR_AREA_ID), has_at_least_one_key(*ENTITY_SERVICE_FIELDS),
) )

View file

@ -283,7 +283,11 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# If the service function is a string, we'll pass it the service call data # If the service function is a string, we'll pass it the service call data
if isinstance(func, str): if isinstance(func, str):
data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID} data = {
key: val
for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS
}
# If the service function is not a string, we pass the service call # If the service function is not a string, we pass the service call
else: else:
data = call data = call
@ -323,6 +327,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
for platform in platforms: for platform in platforms:
platform_entities = [] platform_entities = []
for entity in platform.entities.values(): for entity in platform.entities.values():
if entity.entity_id not in entity_ids: if entity.entity_id not in entity_ids:
continue continue
@ -380,7 +385,7 @@ async def _handle_service_platform_call(
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
_LOGGER.error( _LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.", "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func, func,
entity.entity_id, entity.entity_id,
) )

View file

@ -320,14 +320,20 @@ async def test_call_with_sync_func(hass, mock_entities):
async def test_call_with_sync_attr(hass, mock_entities): async def test_call_with_sync_attr(hass, mock_entities):
"""Test invoking sync service calls.""" """Test invoking sync service calls."""
mock_entities["light.kitchen"].sync_method = Mock() mock_method = mock_entities["light.kitchen"].sync_method = Mock()
await service.entity_service_call( await service.entity_service_call(
hass, hass,
[Mock(entities=mock_entities)], [Mock(entities=mock_entities)],
"sync_method", "sync_method",
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), ha.ServiceCall(
"test_domain",
"test_service",
{"entity_id": "light.kitchen", "area_id": "abcd"},
),
) )
assert mock_entities["light.kitchen"].sync_method.call_count == 1 assert mock_method.call_count == 1
# We pass empty kwargs because both entity_id and area_id are filtered out
assert mock_method.mock_calls[0][2] == {}
async def test_call_context_user_not_exist(hass): async def test_call_context_user_not_exist(hass):