Fix service annotations (#31402)
* Fix service annotations * Filter area_id from service data * Fix services not accepting entities * Typo
This commit is contained in:
parent
81dbdc6b9c
commit
7687ac8b91
5 changed files with 78 additions and 29 deletions
|
@ -143,11 +143,15 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_SELECT_NEXT, {}, lambda entity, call: entity.async_offset_index(1)
|
SERVICE_SELECT_NEXT,
|
||||||
|
{},
|
||||||
|
callback(lambda entity, call: entity.async_offset_index(1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_SELECT_PREVIOUS, {}, lambda entity, call: entity.async_offset_index(-1)
|
SERVICE_SELECT_PREVIOUS,
|
||||||
|
{},
|
||||||
|
callback(lambda entity, call: entity.async_offset_index(-1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
|
@ -248,7 +252,8 @@ class InputSelect(RestoreEntity):
|
||||||
"""Return unique id for the entity."""
|
"""Return unique id for the entity."""
|
||||||
return self._config[CONF_ID]
|
return self._config[CONF_ID]
|
||||||
|
|
||||||
async def async_select_option(self, option):
|
@callback
|
||||||
|
def async_select_option(self, option):
|
||||||
"""Select new option."""
|
"""Select new option."""
|
||||||
if option not in self._options:
|
if option not in self._options:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
|
@ -260,14 +265,16 @@ class InputSelect(RestoreEntity):
|
||||||
self._current_option = option
|
self._current_option = option
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
async def async_offset_index(self, offset):
|
@callback
|
||||||
|
def async_offset_index(self, offset):
|
||||||
"""Offset current index."""
|
"""Offset current index."""
|
||||||
current_index = self._options.index(self._current_option)
|
current_index = self._options.index(self._current_option)
|
||||||
new_index = (current_index + offset) % len(self._options)
|
new_index = (current_index + offset) % len(self._options)
|
||||||
self._current_option = self._options[new_index]
|
self._current_option = self._options[new_index]
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
async def async_set_options(self, options):
|
@callback
|
||||||
|
def async_set_options(self, options):
|
||||||
"""Set options."""
|
"""Set options."""
|
||||||
self._current_option = options[0]
|
self._current_option = options[0]
|
||||||
self._config[CONF_OPTIONS] = options
|
self._config[CONF_OPTIONS] = options
|
||||||
|
|
|
@ -173,6 +173,23 @@ SCHEMA_WEBSOCKET_GET_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.exten
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rename_keys(**keys):
|
||||||
|
"""Create validator that renames keys.
|
||||||
|
|
||||||
|
Necessary because the service schema names do not match the command parameters.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def rename(value):
|
||||||
|
for to_key, from_key in keys.items():
|
||||||
|
if from_key in value:
|
||||||
|
value[to_key] = value.pop(from_key)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return rename
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Track states and offer events for media_players."""
|
"""Track states and offer events for media_players."""
|
||||||
component = hass.data[DOMAIN] = EntityComponent(
|
component = hass.data[DOMAIN] = EntityComponent(
|
||||||
|
@ -238,30 +255,39 @@ async def async_setup(hass, config):
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_VOLUME_SET,
|
SERVICE_VOLUME_SET,
|
||||||
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float},
|
vol.All(
|
||||||
lambda entity, call: entity.async_set_volume_level(
|
cv.make_entity_service_schema(
|
||||||
volume=call.data[ATTR_MEDIA_VOLUME_LEVEL]
|
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}
|
||||||
|
),
|
||||||
|
_rename_keys(volume=ATTR_MEDIA_VOLUME_LEVEL),
|
||||||
),
|
),
|
||||||
|
"async_set_volume_level",
|
||||||
[SUPPORT_VOLUME_SET],
|
[SUPPORT_VOLUME_SET],
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_VOLUME_MUTE,
|
SERVICE_VOLUME_MUTE,
|
||||||
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean},
|
vol.All(
|
||||||
lambda entity, call: entity.async_mute_volume(
|
cv.make_entity_service_schema(
|
||||||
mute=call.data[ATTR_MEDIA_VOLUME_MUTED]
|
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}
|
||||||
|
),
|
||||||
|
_rename_keys(mute=ATTR_MEDIA_VOLUME_MUTED),
|
||||||
),
|
),
|
||||||
|
"async_mute_volume",
|
||||||
[SUPPORT_VOLUME_MUTE],
|
[SUPPORT_VOLUME_MUTE],
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_MEDIA_SEEK,
|
SERVICE_MEDIA_SEEK,
|
||||||
{
|
vol.All(
|
||||||
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
|
cv.make_entity_service_schema(
|
||||||
vol.Coerce(float), vol.Range(min=0)
|
{
|
||||||
)
|
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
|
||||||
},
|
vol.Coerce(float), vol.Range(min=0)
|
||||||
lambda entity, call: entity.async_media_seek(
|
)
|
||||||
position=call.data[ATTR_MEDIA_SEEK_POSITION]
|
}
|
||||||
|
),
|
||||||
|
_rename_keys(position=ATTR_MEDIA_SEEK_POSITION),
|
||||||
),
|
),
|
||||||
|
"async_media_seek",
|
||||||
[SUPPORT_SEEK],
|
[SUPPORT_SEEK],
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
|
@ -278,12 +304,15 @@ async def async_setup(hass, config):
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
SERVICE_PLAY_MEDIA,
|
SERVICE_PLAY_MEDIA,
|
||||||
MEDIA_PLAYER_PLAY_MEDIA_SCHEMA,
|
vol.All(
|
||||||
lambda entity, call: entity.async_play_media(
|
cv.make_entity_service_schema(MEDIA_PLAYER_PLAY_MEDIA_SCHEMA),
|
||||||
media_type=call.data[ATTR_MEDIA_CONTENT_TYPE],
|
_rename_keys(
|
||||||
media_id=call.data[ATTR_MEDIA_CONTENT_ID],
|
media_type=ATTR_MEDIA_CONTENT_TYPE,
|
||||||
enqueue=call.data.get(ATTR_MEDIA_ENQUEUE),
|
media_id=ATTR_MEDIA_CONTENT_ID,
|
||||||
|
enqueue=ATTR_MEDIA_ENQUEUE,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
|
"async_play_media",
|
||||||
[SUPPORT_PLAY_MEDIA],
|
[SUPPORT_PLAY_MEDIA],
|
||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
|
|
|
@ -724,6 +724,8 @@ PLATFORM_SCHEMA = vol.Schema(
|
||||||
|
|
||||||
PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
ENTITY_SERVICE_FIELDS = (ATTR_ENTITY_ID, ATTR_AREA_ID)
|
||||||
|
|
||||||
|
|
||||||
def make_entity_service_schema(
|
def make_entity_service_schema(
|
||||||
schema: dict, *, extra: int = vol.PREVENT_EXTRA
|
schema: dict, *, extra: int = vol.PREVENT_EXTRA
|
||||||
|
@ -738,7 +740,7 @@ def make_entity_service_schema(
|
||||||
},
|
},
|
||||||
extra=extra,
|
extra=extra,
|
||||||
),
|
),
|
||||||
has_at_least_one_key(ATTR_ENTITY_ID, ATTR_AREA_ID),
|
has_at_least_one_key(*ENTITY_SERVICE_FIELDS),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -283,7 +283,11 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||||
|
|
||||||
# If the service function is a string, we'll pass it the service call data
|
# If the service function is a string, we'll pass it the service call data
|
||||||
if isinstance(func, str):
|
if isinstance(func, str):
|
||||||
data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID}
|
data = {
|
||||||
|
key: val
|
||||||
|
for key, val in call.data.items()
|
||||||
|
if key not in cv.ENTITY_SERVICE_FIELDS
|
||||||
|
}
|
||||||
# If the service function is not a string, we pass the service call
|
# If the service function is not a string, we pass the service call
|
||||||
else:
|
else:
|
||||||
data = call
|
data = call
|
||||||
|
@ -323,6 +327,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||||
for platform in platforms:
|
for platform in platforms:
|
||||||
platform_entities = []
|
platform_entities = []
|
||||||
for entity in platform.entities.values():
|
for entity in platform.entities.values():
|
||||||
|
|
||||||
if entity.entity_id not in entity_ids:
|
if entity.entity_id not in entity_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -380,7 +385,7 @@ async def _handle_service_platform_call(
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
|
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
|
||||||
func,
|
func,
|
||||||
entity.entity_id,
|
entity.entity_id,
|
||||||
)
|
)
|
||||||
|
|
|
@ -320,14 +320,20 @@ async def test_call_with_sync_func(hass, mock_entities):
|
||||||
|
|
||||||
async def test_call_with_sync_attr(hass, mock_entities):
|
async def test_call_with_sync_attr(hass, mock_entities):
|
||||||
"""Test invoking sync service calls."""
|
"""Test invoking sync service calls."""
|
||||||
mock_entities["light.kitchen"].sync_method = Mock()
|
mock_method = mock_entities["light.kitchen"].sync_method = Mock()
|
||||||
await service.entity_service_call(
|
await service.entity_service_call(
|
||||||
hass,
|
hass,
|
||||||
[Mock(entities=mock_entities)],
|
[Mock(entities=mock_entities)],
|
||||||
"sync_method",
|
"sync_method",
|
||||||
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
ha.ServiceCall(
|
||||||
|
"test_domain",
|
||||||
|
"test_service",
|
||||||
|
{"entity_id": "light.kitchen", "area_id": "abcd"},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert mock_entities["light.kitchen"].sync_method.call_count == 1
|
assert mock_method.call_count == 1
|
||||||
|
# We pass empty kwargs because both entity_id and area_id are filtered out
|
||||||
|
assert mock_method.mock_calls[0][2] == {}
|
||||||
|
|
||||||
|
|
||||||
async def test_call_context_user_not_exist(hass):
|
async def test_call_context_user_not_exist(hass):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue