Use entity.async_request_call in service helper (#31454)

* Use entity.async_request_call in service helper

* Clean up semaphore handling

* Address comments

* Simplify call entity service helper

* Fix stupid rflink test
This commit is contained in:
Paulus Schoutsen 2020-02-04 15:30:15 -08:00 committed by GitHub
parent 2c439af165
commit e970177eeb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 211 additions and 193 deletions

View file

@ -23,6 +23,8 @@ from . import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
TYPE_STANDARD = "standard" TYPE_STANDARD = "standard"
TYPE_INVERTED = "inverted" TYPE_INVERTED = "inverted"

View file

@ -31,6 +31,8 @@ from . import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
TYPE_DIMMABLE = "dimmable" TYPE_DIMMABLE = "dimmable"
TYPE_SWITCHABLE = "switchable" TYPE_SWITCHABLE = "switchable"
TYPE_HYBRID = "hybrid" TYPE_HYBRID = "hybrid"

View file

@ -22,6 +22,8 @@ from . import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{ {
vol.Optional( vol.Optional(

View file

@ -568,7 +568,6 @@ class Entity(ABC):
# call an requests # call an requests
async def async_request_call(self, coro): async def async_request_call(self, coro):
"""Process request batched.""" """Process request batched."""
if self.parallel_updates: if self.parallel_updates:
await self.parallel_updates.acquire() await self.parallel_updates.acquire()

View file

@ -62,22 +62,42 @@ class EntityPlatform:
# Platform is None for the EntityComponent "catch-all" EntityPlatform # Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities # which powers entity_component.add_entities
if platform is None: if platform is None:
self.parallel_updates = None self.parallel_updates_created = True
self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None self.parallel_updates: Optional[asyncio.Semaphore] = None
return return
self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None) self.parallel_updates_created = False
# semaphore will be created on demand self.parallel_updates = None
self.parallel_updates_semaphore = None
def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore: @callback
"""Get or create a semaphore for parallel updates.""" def _get_parallel_updates_semaphore(
if self.parallel_updates_semaphore is None: self, entity_has_async_update: bool
self.parallel_updates_semaphore = asyncio.Semaphore( ) -> Optional[asyncio.Semaphore]:
self.parallel_updates if self.parallel_updates else 1, """Get or create a semaphore for parallel updates.
loop=self.hass.loop,
) Semaphore will be created on demand because we base it off if update method is async or not.
return self.parallel_updates_semaphore
If parallel updates is set to 0, we skip the semaphore.
If parallel updates is set to a number, we initialize the semaphore to that number.
Default for entities with `async_update` method is 1. Otherwise it's 0.
"""
if self.parallel_updates_created:
return self.parallel_updates
self.parallel_updates_created = True
parallel_updates = getattr(self.platform, "PARALLEL_UPDATES", None)
if parallel_updates is None and not entity_has_async_update:
parallel_updates = 1
if parallel_updates == 0:
parallel_updates = None
if parallel_updates is not None:
self.parallel_updates = asyncio.Semaphore(parallel_updates)
return self.parallel_updates
async def async_setup(self, platform_config, discovery_info=None): async def async_setup(self, platform_config, discovery_info=None):
"""Set up the platform from a config file.""" """Set up the platform from a config file."""
@ -282,21 +302,9 @@ class EntityPlatform:
entity.hass = self.hass entity.hass = self.hass
entity.platform = self entity.platform = self
entity.parallel_updates = self._get_parallel_updates_semaphore(
# Async entity hasattr(entity, "async_update")
# PARALLEL_UPDATES == None: entity.parallel_updates = None )
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
# Sync entity
# PARALLEL_UPDATES == None: entity.parallel_updates = Semaphore(1)
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
if hasattr(entity, "async_update") and not self.parallel_updates:
entity.parallel_updates = None
elif not hasattr(entity, "async_update") and self.parallel_updates == 0:
entity.parallel_updates = None
else:
entity.parallel_updates = self._get_parallel_updates_semaphore()
# Update properties before we generate the entity_id # Update properties before we generate the entity_id
if update_before_add: if update_before_add:

View file

@ -316,16 +316,15 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# Check the permissions # Check the permissions
# A list with for each platform in platforms a list of entities to call # A list with entities to call the service on.
# the service on. entity_candidates = []
platforms_entities = []
if entity_perms is None: if entity_perms is None:
for platform in platforms: for platform in platforms:
if target_all_entities: if target_all_entities:
platforms_entities.append(list(platform.entities.values())) entity_candidates.extend(platform.entities.values())
else: else:
platforms_entities.append( entity_candidates.extend(
[ [
entity entity
for entity in platform.entities.values() for entity in platform.entities.values()
@ -337,7 +336,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# If we target all entities, we will select all entities the user # If we target all entities, we will select all entities the user
# is allowed to control. # is allowed to control.
for platform in platforms: for platform in platforms:
platforms_entities.append( entity_candidates.extend(
[ [
entity entity
for entity in platform.entities.values() for entity in platform.entities.values()
@ -362,39 +361,20 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
platform_entities.append(entity) platform_entities.append(entity)
platforms_entities.append(platform_entities) entity_candidates.extend(platform_entities)
if not target_all_entities: if not target_all_entities:
for platform_entities in platforms_entities: for entity in entity_candidates:
for entity in platform_entities: entity_ids.remove(entity.entity_id)
entity_ids.remove(entity.entity_id)
if entity_ids: if entity_ids:
_LOGGER.warning( _LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids)) "Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
) )
tasks = [ entities = []
_handle_service_platform_call(
hass, func, data, entities, call.context, required_features
)
for platform, entities in zip(platforms, platforms_entities)
]
if tasks: for entity in entity_candidates:
done, pending = await asyncio.wait(tasks)
assert not pending
for future in done:
future.result() # pop exception if have
async def _handle_service_platform_call(
hass, func, data, entities, context, required_features
):
"""Handle a function call."""
tasks = []
for entity in entities:
if not entity.available: if not entity.available:
continue continue
@ -404,27 +384,33 @@ async def _handle_service_platform_call(
): ):
continue continue
entity.async_set_context(context) entities.append(entity)
if isinstance(func, str): if not entities:
result = hass.async_add_job(partial(getattr(entity, func), **data)) return
else:
result = hass.async_add_job(func, entity, data)
# Guard because callback functions do not return a task when passed to async_add_job. done, pending = await asyncio.wait(
if result is not None: [
result = await result entity.async_request_call(
_handle_entity_call(hass, entity, func, data, call.context)
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
) )
await result for entity in entities
]
)
assert not pending
for future in done:
future.result() # pop exception if have
if entity.should_poll: tasks = []
tasks.append(entity.async_update_ha_state(True))
for entity in entities:
if not entity.should_poll:
continue
# Context expires if the turn on commands took a long time.
# Set context again so it's there when we update
entity.async_set_context(call.context)
tasks.append(entity.async_update_ha_state(True))
if tasks: if tasks:
done, pending = await asyncio.wait(tasks) done, pending = await asyncio.wait(tasks)
@ -433,6 +419,28 @@ async def _handle_service_platform_call(
future.result() # pop exception if have future.result() # pop exception if have
async def _handle_entity_call(hass, entity, func, data, context):
"""Handle calling service method."""
entity.async_set_context(context)
if isinstance(func, str):
result = hass.async_add_job(partial(getattr(entity, func), **data))
else:
result = hass.async_add_job(func, entity, data)
# Guard because callback functions do not return a task when passed to async_add_job.
if result is not None:
await result
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
)
await result
@bind_hass @bind_hass
@ha.callback @ha.callback
def async_register_admin_service( def async_register_admin_service(
@ -474,6 +482,7 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
return await service_handler(call) return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None:
raise UnknownUser( raise UnknownUser(
context=call.context, context=call.context,
@ -482,14 +491,12 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
) )
reg = await hass.helpers.entity_registry.async_get_registry() reg = await hass.helpers.entity_registry.async_get_registry()
entities = [
entity.entity_id
for entity in reg.entities.values()
if entity.platform == domain
]
for entity_id in entities: for entity in reg.entities.values():
if user.permissions.check_entity(entity_id, POLICY_CONTROL): if entity.platform != domain:
continue
if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL):
return await service_handler(call) return await service_handler(call)
raise Unauthorized( raise Unauthorized(

View file

@ -270,8 +270,6 @@ async def test_parallel_updates_async_platform_with_constant(hass):
handle = list(component._platforms.values())[-1] handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
class AsyncEntity(MockEntity): class AsyncEntity(MockEntity):
"""Mock entity that has async_update.""" """Mock entity that has async_update."""
@ -296,7 +294,6 @@ async def test_parallel_updates_sync_platform(hass):
await component.async_setup({DOMAIN: {"platform": "platform"}}) await component.async_setup({DOMAIN: {"platform": "platform"}})
handle = list(component._platforms.values())[-1] handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
class SyncEntity(MockEntity): class SyncEntity(MockEntity):
"""Mock entity that has update.""" """Mock entity that has update."""
@ -323,7 +320,6 @@ async def test_parallel_updates_sync_platform_with_constant(hass):
await component.async_setup({DOMAIN: {"platform": "platform"}}) await component.async_setup({DOMAIN: {"platform": "platform"}})
handle = list(component._platforms.values())[-1] handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
class SyncEntity(MockEntity): class SyncEntity(MockEntity):
"""Mock entity that has update.""" """Mock entity that has update."""

View file

@ -39,31 +39,29 @@ from tests.common import (
@pytest.fixture @pytest.fixture
def mock_service_platform_call(): def mock_handle_entity_call():
"""Mock service platform call.""" """Mock service platform call."""
with patch( with patch(
"homeassistant.helpers.service._handle_service_platform_call", "homeassistant.helpers.service._handle_entity_call",
side_effect=lambda *args: mock_coro(), side_effect=lambda *args: mock_coro(),
) as mock_call: ) as mock_call:
yield mock_call yield mock_call
@pytest.fixture @pytest.fixture
def mock_entities(): def mock_entities(hass):
"""Return mock entities in an ordered dict.""" """Return mock entities in an ordered dict."""
kitchen = Mock( kitchen = MockEntity(
entity_id="light.kitchen", entity_id="light.kitchen",
available=True, available=True,
should_poll=False, should_poll=False,
supported_features=1, supported_features=1,
platform="test_domain",
) )
living_room = Mock( living_room = MockEntity(
entity_id="light.living_room", entity_id="light.living_room",
available=True, available=True,
should_poll=False, should_poll=False,
supported_features=0, supported_features=0,
platform="test_domain",
) )
entities = OrderedDict() entities = OrderedDict()
entities[kitchen.entity_id] = kitchen entities[kitchen.entity_id] = kitchen
@ -374,7 +372,7 @@ async def test_call_context_user_not_exist(hass):
assert err.value.context.user_id == "non-existing" assert err.value.context.user_id == "non-existing"
async def test_call_context_target_all(hass, mock_service_platform_call, mock_entities): async def test_call_context_target_all(hass, mock_handle_entity_call, mock_entities):
"""Check we only target allowed entities if targeting all.""" """Check we only target allowed entities if targeting all."""
with patch( with patch(
"homeassistant.auth.AuthManager.async_get_user", "homeassistant.auth.AuthManager.async_get_user",
@ -398,13 +396,12 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
), ),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3] assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
assert entities == [mock_entities["light.kitchen"]]
async def test_call_context_target_specific( async def test_call_context_target_specific(
hass, mock_service_platform_call, mock_entities hass, mock_handle_entity_call, mock_entities
): ):
"""Check targeting specific entities.""" """Check targeting specific entities."""
with patch( with patch(
@ -429,13 +426,12 @@ async def test_call_context_target_specific(
), ),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3] assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
assert entities == [mock_entities["light.kitchen"]]
async def test_call_context_target_specific_no_auth( async def test_call_context_target_specific_no_auth(
hass, mock_service_platform_call, mock_entities hass, mock_handle_entity_call, mock_entities
): ):
"""Check targeting specific entities without auth.""" """Check targeting specific entities without auth."""
with pytest.raises(exceptions.Unauthorized) as err: with pytest.raises(exceptions.Unauthorized) as err:
@ -459,9 +455,7 @@ async def test_call_context_target_specific_no_auth(
assert err.value.entity_id == "light.kitchen" assert err.value.entity_id == "light.kitchen"
async def test_call_no_context_target_all( async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_entities):
hass, mock_service_platform_call, mock_entities
):
"""Check we target all if no user context given.""" """Check we target all if no user context given."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
@ -472,13 +466,14 @@ async def test_call_no_context_target_all(
), ),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 2
entities = mock_service_platform_call.mock_calls[0][1][3] assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
assert entities == list(mock_entities.values()) mock_entities.values()
)
async def test_call_no_context_target_specific( async def test_call_no_context_target_specific(
hass, mock_service_platform_call, mock_entities hass, mock_handle_entity_call, mock_entities
): ):
"""Check we can target specified entities.""" """Check we can target specified entities."""
await service.entity_service_call( await service.entity_service_call(
@ -492,13 +487,12 @@ async def test_call_no_context_target_specific(
), ),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3] assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
assert entities == [mock_entities["light.kitchen"]]
async def test_call_with_match_all( async def test_call_with_match_all(
hass, mock_service_platform_call, mock_entities, caplog hass, mock_handle_entity_call, mock_entities, caplog
): ):
"""Check we only target allowed entities if targeting all.""" """Check we only target allowed entities if targeting all."""
await service.entity_service_call( await service.entity_service_call(
@ -508,20 +502,13 @@ async def test_call_with_match_all(
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 2
entities = mock_service_platform_call.mock_calls[0][1][3] assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
assert entities == [ mock_entities.values()
mock_entities["light.kitchen"], )
mock_entities["light.living_room"],
]
assert (
"Not passing an entity ID to a service to target all entities is deprecated"
) not in caplog.text
async def test_call_with_omit_entity_id( async def test_call_with_omit_entity_id(hass, mock_handle_entity_call, mock_entities):
hass, mock_service_platform_call, mock_entities
):
"""Check service call if we do not pass an entity ID.""" """Check service call if we do not pass an entity ID."""
await service.entity_service_call( await service.entity_service_call(
hass, hass,
@ -530,9 +517,7 @@ async def test_call_with_omit_entity_id(
ha.ServiceCall("test_domain", "test_service"), ha.ServiceCall("test_domain", "test_service"),
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_handle_entity_call.mock_calls) == 0
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == []
async def test_register_admin_service(hass, hass_read_only_user, hass_admin_user): async def test_register_admin_service(hass, hass_read_only_user, hass_admin_user):
@ -644,96 +629,113 @@ async def test_domain_control_unknown(hass, mock_entities):
assert len(calls) == 0 assert len(calls) == 0
async def test_domain_control_unauthorized(hass, hass_read_only_user, mock_entities): async def test_domain_control_unauthorized(hass, hass_read_only_user):
"""Test domain verification in a service call with an unauthorized user.""" """Test domain verification in a service call with an unauthorized user."""
calls = [] mock_registry(
hass,
async def mock_service_log(call): {
"""Define a protected service.""" "light.kitchen": ent_reg.RegistryEntry(
calls.append(call) entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
with patch(
"homeassistant.helpers.entity_registry.async_get_registry",
return_value=mock_coro(Mock(entities=mock_entities)),
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id),
) )
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id),
)
assert len(calls) == 0
async def test_domain_control_admin(hass, hass_admin_user, mock_entities): async def test_domain_control_admin(hass, hass_admin_user):
"""Test domain verification in a service call with an admin user.""" """Test domain verification in a service call with an admin user."""
mock_registry(
hass,
{
"light.kitchen": ent_reg.RegistryEntry(
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
)
},
)
calls = [] calls = []
async def mock_service_log(call): async def mock_service_log(call):
"""Define a protected service.""" """Define a protected service."""
calls.append(call) calls.append(call)
with patch( protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
"homeassistant.helpers.entity_registry.async_get_registry", mock_service_log
return_value=mock_coro(Mock(entities=mock_entities)), )
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
) )
await hass.services.async_call( await hass.services.async_call(
"test_domain", "test_domain",
"test_service", "test_service",
{}, {},
blocking=True, blocking=True,
context=ha.Context(user_id=hass_admin_user.id), context=ha.Context(user_id=hass_admin_user.id),
) )
assert len(calls) == 1 assert len(calls) == 1
async def test_domain_control_no_user(hass, mock_entities): async def test_domain_control_no_user(hass):
"""Test domain verification in a service call with no user.""" """Test domain verification in a service call with no user."""
mock_registry(
hass,
{
"light.kitchen": ent_reg.RegistryEntry(
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
)
},
)
calls = [] calls = []
async def mock_service_log(call): async def mock_service_log(call):
"""Define a protected service.""" """Define a protected service."""
calls.append(call) calls.append(call)
with patch( protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
"homeassistant.helpers.entity_registry.async_get_registry", mock_service_log
return_value=mock_coro(Mock(entities=mock_entities)), )
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
hass.services.async_register( hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None "test_domain", "test_service", protected_mock_service, schema=None
) )
await hass.services.async_call( await hass.services.async_call(
"test_domain", "test_domain",
"test_service", "test_service",
{}, {},
blocking=True, blocking=True,
context=ha.Context(user_id=None), context=ha.Context(user_id=None),
) )
assert len(calls) == 1 assert len(calls) == 1
async def test_extract_from_service_available_device(hass): async def test_extract_from_service_available_device(hass):