Add foundation for integration services (#30813)

* Add foundation for integration services

* Fix tests

* Remove async_get_platform

* Migrate Sonos partially to EntityPlatform.async_register_entity_service

* Tweaks

* Move other Sonos services to media player domain

* Move other Sonos services to media player domain

* Address comments

* Remove lock

* Fix typos

* Use make_entity_service_schema

* Add area extraction to async_extract_entities

Co-authored-by: Anders Melchiorsen <amelchio@nogoto.net>
This commit is contained in:
Paulus Schoutsen 2020-01-19 17:55:18 -08:00 committed by GitHub
parent f20b3515f2
commit 0c3ffbe282
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 339 additions and 289 deletions

View file

@ -1,13 +1,10 @@
"""Support to embed Sonos.""" """Support to embed Sonos."""
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.media_player import DOMAIN as MP_DOMAIN from homeassistant.components.media_player import DOMAIN as MP_DOMAIN
from homeassistant.const import ATTR_ENTITY_ID, ATTR_TIME, CONF_HOSTS from homeassistant.const import CONF_HOSTS
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from .const import DOMAIN from .const import DOMAIN
@ -33,91 +30,12 @@ CONFIG_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
SERVICE_JOIN = "join"
SERVICE_UNJOIN = "unjoin"
SERVICE_SNAPSHOT = "snapshot"
SERVICE_RESTORE = "restore"
SERVICE_SET_TIMER = "set_sleep_timer"
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
SERVICE_UPDATE_ALARM = "update_alarm"
SERVICE_SET_OPTION = "set_option"
SERVICE_PLAY_QUEUE = "play_queue"
ATTR_SLEEP_TIME = "sleep_time"
ATTR_ALARM_ID = "alarm_id"
ATTR_VOLUME = "volume"
ATTR_ENABLED = "enabled"
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
ATTR_MASTER = "master"
ATTR_WITH_GROUP = "with_group"
ATTR_NIGHT_SOUND = "night_sound"
ATTR_SPEECH_ENHANCE = "speech_enhance"
ATTR_QUEUE_POSITION = "queue_position"
SONOS_JOIN_SCHEMA = vol.Schema(
{
vol.Required(ATTR_MASTER): cv.entity_id,
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
}
)
SONOS_UNJOIN_SCHEMA = vol.Schema({vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids})
SONOS_STATES_SCHEMA = vol.Schema(
{
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean,
}
)
SONOS_SET_TIMER_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_SLEEP_TIME): vol.All(
vol.Coerce(int), vol.Range(min=0, max=86399)
),
}
)
SONOS_CLEAR_TIMER_SCHEMA = vol.Schema(
{vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids}
)
SONOS_UPDATE_ALARM_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_ALARM_ID): cv.positive_int,
vol.Optional(ATTR_TIME): cv.time,
vol.Optional(ATTR_VOLUME): cv.small_float,
vol.Optional(ATTR_ENABLED): cv.boolean,
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
}
)
SONOS_SET_OPTION_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
}
)
SONOS_PLAY_QUEUE_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_QUEUE_POSITION, default=0): cv.positive_int,
}
)
DATA_SERVICE_EVENT = "sonos_service_idle"
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the Sonos component.""" """Set up the Sonos component."""
conf = config.get(DOMAIN) conf = config.get(DOMAIN)
hass.data[DOMAIN] = conf or {} hass.data[DOMAIN] = conf or {}
hass.data[DATA_SERVICE_EVENT] = asyncio.Event()
if conf is not None: if conf is not None:
hass.async_create_task( hass.async_create_task(
@ -126,48 +44,6 @@ async def async_setup(hass, config):
) )
) )
async def service_handle(service):
"""Dispatch a service call."""
hass.data[DATA_SERVICE_EVENT].clear()
async_dispatcher_send(hass, DOMAIN, service.service, service.data)
await hass.data[DATA_SERVICE_EVENT].wait()
hass.services.async_register(
DOMAIN, SERVICE_JOIN, service_handle, schema=SONOS_JOIN_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_UNJOIN, service_handle, schema=SONOS_UNJOIN_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SNAPSHOT, service_handle, schema=SONOS_STATES_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_RESTORE, service_handle, schema=SONOS_STATES_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SET_TIMER, service_handle, schema=SONOS_SET_TIMER_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_CLEAR_TIMER, service_handle, schema=SONOS_CLEAR_TIMER_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_UPDATE_ALARM, service_handle, schema=SONOS_UPDATE_ALARM_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SET_OPTION, service_handle, schema=SONOS_SET_OPTION_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_PLAY_QUEUE, service_handle, schema=SONOS_PLAY_QUEUE_SCHEMA
)
return True return True

View file

@ -11,6 +11,7 @@ import pysonos
from pysonos import alarms from pysonos import alarms
from pysonos.exceptions import SoCoException, SoCoUPnPException from pysonos.exceptions import SoCoException, SoCoUPnPException
import pysonos.snapshot import pysonos.snapshot
import voluptuous as vol
from homeassistant.components.media_player import MediaPlayerDevice from homeassistant.components.media_player import MediaPlayerDevice
from homeassistant.components.media_player.const import ( from homeassistant.components.media_player.const import (
@ -30,42 +31,16 @@ from homeassistant.components.media_player.const import (
SUPPORT_VOLUME_MUTE, SUPPORT_VOLUME_MUTE,
SUPPORT_VOLUME_SET, SUPPORT_VOLUME_SET,
) )
from homeassistant.const import ( from homeassistant.const import ATTR_TIME, STATE_IDLE, STATE_PAUSED, STATE_PLAYING
ENTITY_MATCH_ALL, from homeassistant.core import ServiceCall, callback
STATE_IDLE, from homeassistant.helpers import config_validation as cv, entity_platform, service
STATE_PAUSED,
STATE_PLAYING,
)
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import ( from . import (
ATTR_ALARM_ID,
ATTR_ENABLED,
ATTR_INCLUDE_LINKED_ZONES,
ATTR_MASTER,
ATTR_NIGHT_SOUND,
ATTR_QUEUE_POSITION,
ATTR_SLEEP_TIME,
ATTR_SPEECH_ENHANCE,
ATTR_TIME,
ATTR_VOLUME,
ATTR_WITH_GROUP,
CONF_ADVERTISE_ADDR, CONF_ADVERTISE_ADDR,
CONF_HOSTS, CONF_HOSTS,
CONF_INTERFACE_ADDR, CONF_INTERFACE_ADDR,
DATA_SERVICE_EVENT,
DOMAIN as SONOS_DOMAIN, DOMAIN as SONOS_DOMAIN,
SERVICE_CLEAR_TIMER,
SERVICE_JOIN,
SERVICE_PLAY_QUEUE,
SERVICE_RESTORE,
SERVICE_SET_OPTION,
SERVICE_SET_TIMER,
SERVICE_SNAPSHOT,
SERVICE_UNJOIN,
SERVICE_UPDATE_ALARM,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -97,6 +72,27 @@ ATTR_SONOS_GROUP = "sonos_group"
UPNP_ERRORS_TO_IGNORE = ["701", "711", "712"] UPNP_ERRORS_TO_IGNORE = ["701", "711", "712"]
SERVICE_JOIN = "join"
SERVICE_UNJOIN = "unjoin"
SERVICE_SNAPSHOT = "snapshot"
SERVICE_RESTORE = "restore"
SERVICE_SET_TIMER = "set_sleep_timer"
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
SERVICE_UPDATE_ALARM = "update_alarm"
SERVICE_SET_OPTION = "set_option"
SERVICE_PLAY_QUEUE = "play_queue"
ATTR_SLEEP_TIME = "sleep_time"
ATTR_ALARM_ID = "alarm_id"
ATTR_VOLUME = "volume"
ATTR_ENABLED = "enabled"
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
ATTR_MASTER = "master"
ATTR_WITH_GROUP = "with_group"
ATTR_NIGHT_SOUND = "night_sound"
ATTR_SPEECH_ENHANCE = "speech_enhance"
ATTR_QUEUE_POSITION = "queue_position"
class SonosData: class SonosData:
"""Storage class for platform global data.""" """Storage class for platform global data."""
@ -176,46 +172,101 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
_LOGGER.debug("Adding discovery job") _LOGGER.debug("Adding discovery job")
hass.async_add_executor_job(_discovery) hass.async_add_executor_job(_discovery)
async def async_service_handle(service, data): platform = entity_platform.current_platform.get()
async def async_service_handle(service_call: ServiceCall):
"""Handle dispatched services.""" """Handle dispatched services."""
entity_ids = data.get("entity_id") entities = await platform.async_extract_from_service(service_call)
entities = hass.data[DATA_SONOS].entities
if entity_ids and entity_ids != ENTITY_MATCH_ALL:
entities = [e for e in entities if e.entity_id in entity_ids]
if service == SERVICE_JOIN: if not entities:
master = [ return
e
for e in hass.data[DATA_SONOS].entities if service_call.service == SERVICE_JOIN:
if e.entity_id == data[ATTR_MASTER] master = platform.entities.get(service_call.data[ATTR_MASTER])
]
if master: if master:
await SonosEntity.join_multi(hass, master[0], entities) await SonosEntity.join_multi(hass, master, entities)
elif service == SERVICE_UNJOIN: else:
_LOGGER.error(
"Invalid master specified for join service: %s",
service_call.data[ATTR_MASTER],
)
elif service_call.service == SERVICE_UNJOIN:
await SonosEntity.unjoin_multi(hass, entities) await SonosEntity.unjoin_multi(hass, entities)
elif service == SERVICE_SNAPSHOT: elif service_call.service == SERVICE_SNAPSHOT:
await SonosEntity.snapshot_multi(hass, entities, data[ATTR_WITH_GROUP]) await SonosEntity.snapshot_multi(
elif service == SERVICE_RESTORE: hass, entities, service_call.data[ATTR_WITH_GROUP]
await SonosEntity.restore_multi(hass, entities, data[ATTR_WITH_GROUP]) )
else: elif service_call.service == SERVICE_RESTORE:
for entity in entities: await SonosEntity.restore_multi(
if service == SERVICE_SET_TIMER: hass, entities, service_call.data[ATTR_WITH_GROUP]
call = entity.set_sleep_timer )
elif service == SERVICE_CLEAR_TIMER:
call = entity.clear_sleep_timer
elif service == SERVICE_UPDATE_ALARM:
call = entity.set_alarm
elif service == SERVICE_SET_OPTION:
call = entity.set_option
elif service == SERVICE_PLAY_QUEUE:
call = entity.play_queue
hass.async_add_executor_job(call, data) service.async_register_admin_service(
hass,
SONOS_DOMAIN,
SERVICE_JOIN,
async_service_handle,
cv.make_entity_service_schema({vol.Required(ATTR_MASTER): cv.entity_id}),
)
# We are ready for the next service call service.async_register_admin_service(
hass.data[DATA_SERVICE_EVENT].set() hass,
SONOS_DOMAIN,
SERVICE_UNJOIN,
async_service_handle,
cv.make_entity_service_schema({}),
)
async_dispatcher_connect(hass, SONOS_DOMAIN, async_service_handle) join_unjoin_schema = cv.make_entity_service_schema(
{vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean}
)
service.async_register_admin_service(
hass, SONOS_DOMAIN, SERVICE_SNAPSHOT, async_service_handle, join_unjoin_schema
)
service.async_register_admin_service(
hass, SONOS_DOMAIN, SERVICE_RESTORE, async_service_handle, join_unjoin_schema
)
platform.async_register_entity_service(
SERVICE_SET_TIMER,
{
vol.Required(ATTR_SLEEP_TIME): vol.All(
vol.Coerce(int), vol.Range(min=0, max=86399)
)
},
"set_sleep_timer",
)
platform.async_register_entity_service(SERVICE_CLEAR_TIMER, {}, "clear_sleep_timer")
platform.async_register_entity_service(
SERVICE_UPDATE_ALARM,
{
vol.Required(ATTR_ALARM_ID): cv.positive_int,
vol.Optional(ATTR_TIME): cv.time,
vol.Optional(ATTR_VOLUME): cv.small_float,
vol.Optional(ATTR_ENABLED): cv.boolean,
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
},
"set_alarm",
)
platform.async_register_entity_service(
SERVICE_SET_OPTION,
{
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
},
"set_option",
)
platform.async_register_entity_service(
SERVICE_PLAY_QUEUE,
{vol.Optional(ATTR_QUEUE_POSITION): cv.positive_int},
"play_queue",
)
class _ProcessSonosEventQueue: class _ProcessSonosEventQueue:
@ -480,10 +531,10 @@ class SonosEntity(MediaPlayerDevice):
player = self.soco player = self.soco
def subscribe(service, action): def subscribe(sonos_service, action):
"""Add a subscription to a pysonos service.""" """Add a subscription to a pysonos service."""
queue = _ProcessSonosEventQueue(action) queue = _ProcessSonosEventQueue(action)
sub = service.subscribe(auto_renew=True, event_queue=queue) sub = sonos_service.subscribe(auto_renew=True, event_queue=queue)
self._subscriptions.append(sub) self._subscriptions.append(sub)
subscribe(player.avTransport, self.update_media) subscribe(player.avTransport, self.update_media)
@ -1147,52 +1198,53 @@ class SonosEntity(MediaPlayerDevice):
@soco_error() @soco_error()
@soco_coordinator @soco_coordinator
def set_sleep_timer(self, data): def set_sleep_timer(self, sleep_time):
"""Set the timer on the player.""" """Set the timer on the player."""
self.soco.set_sleep_timer(data[ATTR_SLEEP_TIME]) self.soco.set_sleep_timer(sleep_time)
@soco_error() @soco_error()
@soco_coordinator @soco_coordinator
def clear_sleep_timer(self, data): def clear_sleep_timer(self):
"""Clear the timer on the player.""" """Clear the timer on the player."""
self.soco.set_sleep_timer(None) self.soco.set_sleep_timer(None)
@soco_error() @soco_error()
@soco_coordinator @soco_coordinator
def set_alarm(self, data): def set_alarm(
self, alarm_id, time=None, volume=None, enabled=None, include_linked_zones=None
):
"""Set the alarm clock on the player.""" """Set the alarm clock on the player."""
alarm = None alarm = None
for one_alarm in alarms.get_alarms(self.soco): for one_alarm in alarms.get_alarms(self.soco):
# pylint: disable=protected-access # pylint: disable=protected-access
if one_alarm._alarm_id == str(data[ATTR_ALARM_ID]): if one_alarm._alarm_id == str(alarm_id):
alarm = one_alarm alarm = one_alarm
if alarm is None: if alarm is None:
_LOGGER.warning("did not find alarm with id %s", data[ATTR_ALARM_ID]) _LOGGER.warning("did not find alarm with id %s", alarm_id)
return return
if ATTR_TIME in data: if time is not None:
alarm.start_time = data[ATTR_TIME] alarm.start_time = time
if ATTR_VOLUME in data: if volume is not None:
alarm.volume = int(data[ATTR_VOLUME] * 100) alarm.volume = int(volume * 100)
if ATTR_ENABLED in data: if enabled is not None:
alarm.enabled = data[ATTR_ENABLED] alarm.enabled = enabled
if ATTR_INCLUDE_LINKED_ZONES in data: if include_linked_zones is not None:
alarm.include_linked_zones = data[ATTR_INCLUDE_LINKED_ZONES] alarm.include_linked_zones = include_linked_zones
alarm.save() alarm.save()
@soco_error() @soco_error()
def set_option(self, data): def set_option(self, night_sound=None, speech_enhance=None):
"""Modify playback options.""" """Modify playback options."""
if ATTR_NIGHT_SOUND in data and self._night_sound is not None: if night_sound is not None and self._night_sound is not None:
self.soco.night_mode = data[ATTR_NIGHT_SOUND] self.soco.night_mode = night_sound
if ATTR_SPEECH_ENHANCE in data and self._speech_enhance is not None: if speech_enhance is not None and self._speech_enhance is not None:
self.soco.dialog_mode = data[ATTR_SPEECH_ENHANCE] self.soco.dialog_mode = speech_enhance
@soco_error() @soco_error()
def play_queue(self, data): def play_queue(self, queue_position=0):
"""Start playing the queue.""" """Start playing the queue."""
self.soco.play_from_queue(data[ATTR_QUEUE_POSITION]) self.soco.play_from_queue(queue_position)
@property @property
def device_state_attributes(self): def device_state_attributes(self):

View file

@ -6,17 +6,15 @@ import logging
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL
ATTR_ENTITY_ID,
CONF_ENTITY_NAMESPACE,
CONF_SCAN_INTERVAL,
ENTITY_MATCH_ALL,
)
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import (
from homeassistant.helpers.config_validation import make_entity_service_schema config_per_platform,
from homeassistant.helpers.service import async_extract_entity_ids config_validation as cv,
discovery,
service,
)
from homeassistant.loader import async_get_integration, bind_hass from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
@ -166,39 +164,27 @@ class EntityComponent:
await platform.async_reset() await platform.async_reset()
return True return True
async def async_extract_from_service(self, service, expand_group=True): async def async_extract_from_service(self, service_call, expand_group=True):
"""Extract all known and available entities from a service call. """Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown. Will return an empty list if entities specified but unknown.
This method must be run in the event loop. This method must be run in the event loop.
""" """
data_ent_id = service.data.get(ATTR_ENTITY_ID) return await service.async_extract_entities(
self.hass, self.entities, service_call, expand_group
if data_ent_id is None: )
return []
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in self.entities if entity.available]
entity_ids = await async_extract_entity_ids(self.hass, service, expand_group)
return [
entity
for entity in self.entities
if entity.available and entity.entity_id in entity_ids
]
@callback @callback
def async_register_entity_service(self, name, schema, func, required_features=None): def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service.""" """Register an entity service."""
if isinstance(schema, dict): if isinstance(schema, dict):
schema = make_entity_service_schema(schema) schema = cv.make_entity_service_schema(schema)
async def handle_service(call): async def handle_service(call):
"""Handle the service.""" """Handle the service."""
service_name = f"{self.domain}.{name}"
await self.hass.helpers.service.entity_service_call( await self.hass.helpers.service.entity_service_call(
self._platforms.values(), func, call, service_name, required_features self._platforms.values(), func, call, required_features
) )
self.hass.services.async_register(self.domain, name, handle_service, schema) self.hass.services.async_register(self.domain, name, handle_service, schema)

View file

@ -7,6 +7,7 @@ from typing import Optional
from homeassistant.const import DEVICE_DEFAULT_NAME from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import config_validation as cv, service
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from .entity_registry import DISABLED_INTEGRATION from .entity_registry import DISABLED_INTEGRATION
@ -194,7 +195,11 @@ class EntityPlatform:
) )
return False return False
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception("Error while setting up platform %s", self.platform_name) logger.exception(
"Error while setting up %s platform for %s",
self.platform_name,
self.domain,
)
return False return False
finally: finally:
warn_task.cancel() warn_task.cancel()
@ -449,6 +454,33 @@ class EntityPlatform:
self._async_unsub_polling() self._async_unsub_polling()
self._async_unsub_polling = None self._async_unsub_polling = None
async def async_extract_from_service(self, service_call, expand_group=True):
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
return await service.async_extract_entities(
self.hass, self.entities.values(), service_call, expand_group
)
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service."""
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call):
"""Handle the service."""
await service.entity_service_call(
self.hass, [self], func, call, required_features
)
self.hass.services.async_register(
self.platform_name, name, handle_service, schema
)
async def _update_entity_states(self, now: datetime) -> None: async def _update_entity_states(self, now: datetime) -> None:
"""Update the states of all the polling entities. """Update the states of all the polling entities.

View file

@ -108,13 +108,31 @@ def extract_entity_ids(hass, service_call, expand_group=True):
).result() ).result()
@bind_hass
async def async_extract_entities(hass, entities, service_call, expand_group=True):
"""Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents.
"""
data_ent_id = service_call.data.get(ATTR_ENTITY_ID)
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available]
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group)
return [
entity
for entity in entities
if entity.available and entity.entity_id in entity_ids
]
@bind_hass @bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True): async def async_extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call. """Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
Async friendly.
""" """
entity_ids = service_call.data.get(ATTR_ENTITY_ID) entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID) area_ids = service_call.data.get(ATTR_AREA_ID)
@ -244,9 +262,7 @@ def async_set_service_schema(hass, domain, service, schema):
@bind_hass @bind_hass
async def entity_service_call( async def entity_service_call(hass, platforms, func, call, required_features=None):
hass, platforms, func, call, service_name="", required_features=None
):
"""Handle an entity service call. """Handle an entity service call.
Calls all platforms simultaneously. Calls all platforms simultaneously.

View file

@ -23,6 +23,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
MockEntity,
get_test_home_assistant, get_test_home_assistant,
mock_coro, mock_coro,
mock_device_registry, mock_device_registry,
@ -64,6 +65,54 @@ def mock_entities():
return entities return entities
@pytest.fixture
def area_mock(hass):
"""Mock including area info."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
},
)
entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
mock_registry(
hass,
{
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
},
)
class TestServiceHelpers(unittest.TestCase): class TestServiceHelpers(unittest.TestCase):
"""Test the Home Assistant service helpers.""" """Test the Home Assistant service helpers."""
@ -204,52 +253,8 @@ async def test_extract_entity_ids(hass):
) )
async def test_extract_entity_ids_from_area(hass): async def test_extract_entity_ids_from_area(hass, area_mock):
"""Test extract_entity_ids method with areas.""" """Test extract_entity_ids method with areas."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
},
)
entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
mock_registry(
hass,
{
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
},
)
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"}) call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call) assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call)
@ -678,3 +683,86 @@ async def test_domain_control_no_user(hass, mock_entities):
) )
assert len(calls) == 1 assert len(calls) == 1
async def test_extract_from_service_available_device(hass):
"""Test the extraction of entity from service and device is available."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2", available=False),
MockEntity(name="test_3", entity_id="test_domain.test_3"),
MockEntity(name="test_4", entity_id="test_domain.test_4", available=False),
]
call_1 = ha.ServiceCall("test", "service", data={"entity_id": ENTITY_MATCH_ALL})
assert ["test_domain.test_1", "test_domain.test_3"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call_1))
]
call_2 = ha.ServiceCall(
"test",
"service",
data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]},
)
assert ["test_domain.test_3"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call_2))
]
async def test_extract_from_service_empty_if_no_entity_id(hass):
"""Test the extraction from service without specifying entity."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2"),
]
call = ha.ServiceCall("test", "service")
assert [] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call))
]
async def test_extract_from_service_filter_out_non_existing_entities(hass):
"""Test the extraction of non existing entities from service."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2"),
]
call = ha.ServiceCall(
"test",
"service",
{"entity_id": ["test_domain.test_2", "test_domain.non_exist"]},
)
assert ["test_domain.test_2"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call))
]
async def test_extract_from_service_area_id(hass, area_mock):
"""Test the extraction using area ID as reference."""
entities = [
MockEntity(name="in_area", entity_id="light.in_area"),
MockEntity(name="no_area", entity_id="light.no_area"),
MockEntity(name="diff_area", entity_id="light.diff_area"),
]
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
extracted = await service.async_extract_entities(hass, entities, call)
assert len(extracted) == 1
assert extracted[0].entity_id == "light.in_area"
call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]})
extracted = await service.async_extract_entities(hass, entities, call)
assert len(extracted) == 2
assert sorted(ent.entity_id for ent in extracted) == [
"light.diff_area",
"light.in_area",
]