Cache the job type for entity service calls (#112793)

This commit is contained in:
J. Nick Koston 2024-03-08 22:49:08 -10:00 committed by GitHub
parent 19e54debba
commit b7d9f26cee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 4 deletions

View file

@ -299,7 +299,7 @@ class HassJob(Generic[_P, _R_co]):
@cached_property
def job_type(self) -> HassJobType:
"""Return the job type."""
return self._job_type or _get_hassjob_callable_job_type(self.target)
return self._job_type or get_hassjob_callable_job_type(self.target)
@property
def cancel_on_shutdown(self) -> bool | None:
@ -319,7 +319,7 @@ class HassJobWithArgs:
args: Iterable[Any]
def _get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
def get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
"""Determine the job type from the callable."""
# Check for partials to properly determine if coroutine function
check_target = target

View file

@ -51,6 +51,7 @@ from homeassistant.core import (
HassJobType,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
get_release_channel,
)
from homeassistant.exceptions import (
@ -527,6 +528,8 @@ class Entity(
__combined_unrecorded_attributes: frozenset[str] = (
_entity_component_unrecorded_attributes | _unrecorded_attributes
)
# Job type cache
_job_types: dict[str, HassJobType] | None = None
# StateInfo. Set by EntityPlatform by calling async_internal_added_to_hass
# While not purely typed, it makes typehinting more useful for us
@ -568,6 +571,20 @@ class Entity(
cls._entity_component_unrecorded_attributes | cls._unrecorded_attributes
)
def get_hassjob_type(self, function_name: str) -> HassJobType:
"""Get the job type function for the given name.
This is used for entity service calls to avoid
figuring out the job type each time.
"""
if not self._job_types:
self._job_types = {}
if function_name not in self._job_types:
self._job_types[function_name] = get_hassjob_callable_job_type(
getattr(self, function_name)
)
return self._job_types[function_name]
@cached_property
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.

View file

@ -963,7 +963,10 @@ async def _handle_entity_call(
task: asyncio.Future[ServiceResponse] | None
if isinstance(func, str):
job = HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type]
job = HassJob(
partial(getattr(entity, func), **data), # type: ignore[arg-type]
job_type=entity.get_hassjob_type(func),
)
task = hass.async_run_hass_job(job, eager_start=True)
else:
task = hass.async_run_hass_job(func, entity, data, eager_start=True)

View file

@ -23,7 +23,13 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import Context, HomeAssistant, HomeAssistantError
from homeassistant.core import (
Context,
HassJobType,
HomeAssistant,
HomeAssistantError,
callback,
)
from homeassistant.helpers import device_registry as dr, entity, entity_registry as er
from homeassistant.helpers.entity_component import async_update_entity
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
@ -2559,3 +2565,26 @@ async def test_reset_right_after_remove_entity_registry(
assert len(ent.remove_calls) == 1
assert hass.states.get("test.test") is None
async def test_get_hassjob_type(hass: HomeAssistant) -> None:
"""Test get_hassjob_type."""
class AsyncEntity(entity.Entity):
"""Test entity."""
def update(self):
"""Test update Executor."""
async def async_update(self):
"""Test update Coroutinefunction."""
@callback
def update_callback(self):
"""Test update Callback."""
ent_1 = AsyncEntity()
assert ent_1.get_hassjob_type("update") is HassJobType.Executor
assert ent_1.get_hassjob_type("async_update") is HassJobType.Coroutinefunction
assert ent_1.get_hassjob_type("update_callback") is HassJobType.Callback