Fix service annotations (#31402)
* Fix service annotations * Filter area_id from service data * Fix services not accepting entities * Typo
This commit is contained in:
parent
81dbdc6b9c
commit
7687ac8b91
5 changed files with 78 additions and 29 deletions
|
@ -143,11 +143,15 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
|||
)
|
||||
|
||||
component.async_register_entity_service(
|
||||
SERVICE_SELECT_NEXT, {}, lambda entity, call: entity.async_offset_index(1)
|
||||
SERVICE_SELECT_NEXT,
|
||||
{},
|
||||
callback(lambda entity, call: entity.async_offset_index(1)),
|
||||
)
|
||||
|
||||
component.async_register_entity_service(
|
||||
SERVICE_SELECT_PREVIOUS, {}, lambda entity, call: entity.async_offset_index(-1)
|
||||
SERVICE_SELECT_PREVIOUS,
|
||||
{},
|
||||
callback(lambda entity, call: entity.async_offset_index(-1)),
|
||||
)
|
||||
|
||||
component.async_register_entity_service(
|
||||
|
@ -248,7 +252,8 @@ class InputSelect(RestoreEntity):
|
|||
"""Return unique id for the entity."""
|
||||
return self._config[CONF_ID]
|
||||
|
||||
async def async_select_option(self, option):
|
||||
@callback
|
||||
def async_select_option(self, option):
|
||||
"""Select new option."""
|
||||
if option not in self._options:
|
||||
_LOGGER.warning(
|
||||
|
@ -260,14 +265,16 @@ class InputSelect(RestoreEntity):
|
|||
self._current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_offset_index(self, offset):
|
||||
@callback
|
||||
def async_offset_index(self, offset):
|
||||
"""Offset current index."""
|
||||
current_index = self._options.index(self._current_option)
|
||||
new_index = (current_index + offset) % len(self._options)
|
||||
self._current_option = self._options[new_index]
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_set_options(self, options):
|
||||
@callback
|
||||
def async_set_options(self, options):
|
||||
"""Set options."""
|
||||
self._current_option = options[0]
|
||||
self._config[CONF_OPTIONS] = options
|
||||
|
|
|
@ -173,6 +173,23 @@ SCHEMA_WEBSOCKET_GET_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.exten
|
|||
)
|
||||
|
||||
|
||||
def _rename_keys(**keys):
|
||||
"""Create validator that renames keys.
|
||||
|
||||
Necessary because the service schema names do not match the command parameters.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
|
||||
def rename(value):
|
||||
for to_key, from_key in keys.items():
|
||||
if from_key in value:
|
||||
value[to_key] = value.pop(from_key)
|
||||
return value
|
||||
|
||||
return rename
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Track states and offer events for media_players."""
|
||||
component = hass.data[DOMAIN] = EntityComponent(
|
||||
|
@ -238,30 +255,39 @@ async def async_setup(hass, config):
|
|||
)
|
||||
component.async_register_entity_service(
|
||||
SERVICE_VOLUME_SET,
|
||||
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float},
|
||||
lambda entity, call: entity.async_set_volume_level(
|
||||
volume=call.data[ATTR_MEDIA_VOLUME_LEVEL]
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}
|
||||
),
|
||||
_rename_keys(volume=ATTR_MEDIA_VOLUME_LEVEL),
|
||||
),
|
||||
"async_set_volume_level",
|
||||
[SUPPORT_VOLUME_SET],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
SERVICE_VOLUME_MUTE,
|
||||
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean},
|
||||
lambda entity, call: entity.async_mute_volume(
|
||||
mute=call.data[ATTR_MEDIA_VOLUME_MUTED]
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}
|
||||
),
|
||||
_rename_keys(mute=ATTR_MEDIA_VOLUME_MUTED),
|
||||
),
|
||||
"async_mute_volume",
|
||||
[SUPPORT_VOLUME_MUTE],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
SERVICE_MEDIA_SEEK,
|
||||
{
|
||||
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
|
||||
vol.Coerce(float), vol.Range(min=0)
|
||||
)
|
||||
},
|
||||
lambda entity, call: entity.async_media_seek(
|
||||
position=call.data[ATTR_MEDIA_SEEK_POSITION]
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
|
||||
vol.Coerce(float), vol.Range(min=0)
|
||||
)
|
||||
}
|
||||
),
|
||||
_rename_keys(position=ATTR_MEDIA_SEEK_POSITION),
|
||||
),
|
||||
"async_media_seek",
|
||||
[SUPPORT_SEEK],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
|
@ -278,12 +304,15 @@ async def async_setup(hass, config):
|
|||
)
|
||||
component.async_register_entity_service(
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MEDIA_PLAYER_PLAY_MEDIA_SCHEMA,
|
||||
lambda entity, call: entity.async_play_media(
|
||||
media_type=call.data[ATTR_MEDIA_CONTENT_TYPE],
|
||||
media_id=call.data[ATTR_MEDIA_CONTENT_ID],
|
||||
enqueue=call.data.get(ATTR_MEDIA_ENQUEUE),
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(MEDIA_PLAYER_PLAY_MEDIA_SCHEMA),
|
||||
_rename_keys(
|
||||
media_type=ATTR_MEDIA_CONTENT_TYPE,
|
||||
media_id=ATTR_MEDIA_CONTENT_ID,
|
||||
enqueue=ATTR_MEDIA_ENQUEUE,
|
||||
),
|
||||
),
|
||||
"async_play_media",
|
||||
[SUPPORT_PLAY_MEDIA],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
|
|
|
@ -724,6 +724,8 @@ PLATFORM_SCHEMA = vol.Schema(
|
|||
|
||||
PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
ENTITY_SERVICE_FIELDS = (ATTR_ENTITY_ID, ATTR_AREA_ID)
|
||||
|
||||
|
||||
def make_entity_service_schema(
|
||||
schema: dict, *, extra: int = vol.PREVENT_EXTRA
|
||||
|
@ -738,7 +740,7 @@ def make_entity_service_schema(
|
|||
},
|
||||
extra=extra,
|
||||
),
|
||||
has_at_least_one_key(ATTR_ENTITY_ID, ATTR_AREA_ID),
|
||||
has_at_least_one_key(*ENTITY_SERVICE_FIELDS),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -283,7 +283,11 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
|||
|
||||
# If the service function is a string, we'll pass it the service call data
|
||||
if isinstance(func, str):
|
||||
data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID}
|
||||
data = {
|
||||
key: val
|
||||
for key, val in call.data.items()
|
||||
if key not in cv.ENTITY_SERVICE_FIELDS
|
||||
}
|
||||
# If the service function is not a string, we pass the service call
|
||||
else:
|
||||
data = call
|
||||
|
@ -323,6 +327,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
|||
for platform in platforms:
|
||||
platform_entities = []
|
||||
for entity in platform.entities.values():
|
||||
|
||||
if entity.entity_id not in entity_ids:
|
||||
continue
|
||||
|
||||
|
@ -380,7 +385,7 @@ async def _handle_service_platform_call(
|
|||
|
||||
if asyncio.iscoroutine(result):
|
||||
_LOGGER.error(
|
||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
|
||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
|
||||
func,
|
||||
entity.entity_id,
|
||||
)
|
||||
|
|
|
@ -320,14 +320,20 @@ async def test_call_with_sync_func(hass, mock_entities):
|
|||
|
||||
async def test_call_with_sync_attr(hass, mock_entities):
|
||||
"""Test invoking sync service calls."""
|
||||
mock_entities["light.kitchen"].sync_method = Mock()
|
||||
mock_method = mock_entities["light.kitchen"].sync_method = Mock()
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
"sync_method",
|
||||
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
||||
ha.ServiceCall(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{"entity_id": "light.kitchen", "area_id": "abcd"},
|
||||
),
|
||||
)
|
||||
assert mock_entities["light.kitchen"].sync_method.call_count == 1
|
||||
assert mock_method.call_count == 1
|
||||
# We pass empty kwargs because both entity_id and area_id are filtered out
|
||||
assert mock_method.mock_calls[0][2] == {}
|
||||
|
||||
|
||||
async def test_call_context_user_not_exist(hass):
|
||||
|
|
Loading…
Add table
Reference in a new issue