Cache the job type for entity service calls (#112793)
This commit is contained in:
parent
19e54debba
commit
b7d9f26cee
4 changed files with 53 additions and 4 deletions
|
@ -299,7 +299,7 @@ class HassJob(Generic[_P, _R_co]):
|
|||
@cached_property
|
||||
def job_type(self) -> HassJobType:
|
||||
"""Return the job type."""
|
||||
return self._job_type or _get_hassjob_callable_job_type(self.target)
|
||||
return self._job_type or get_hassjob_callable_job_type(self.target)
|
||||
|
||||
@property
|
||||
def cancel_on_shutdown(self) -> bool | None:
|
||||
|
@ -319,7 +319,7 @@ class HassJobWithArgs:
|
|||
args: Iterable[Any]
|
||||
|
||||
|
||||
def _get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
|
||||
def get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
|
||||
"""Determine the job type from the callable."""
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_target = target
|
||||
|
|
|
@ -51,6 +51,7 @@ from homeassistant.core import (
|
|||
HassJobType,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
get_hassjob_callable_job_type,
|
||||
get_release_channel,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
|
@ -527,6 +528,8 @@ class Entity(
|
|||
__combined_unrecorded_attributes: frozenset[str] = (
|
||||
_entity_component_unrecorded_attributes | _unrecorded_attributes
|
||||
)
|
||||
# Job type cache
|
||||
_job_types: dict[str, HassJobType] | None = None
|
||||
|
||||
# StateInfo. Set by EntityPlatform by calling async_internal_added_to_hass
|
||||
# While not purely typed, it makes typehinting more useful for us
|
||||
|
@ -568,6 +571,20 @@ class Entity(
|
|||
cls._entity_component_unrecorded_attributes | cls._unrecorded_attributes
|
||||
)
|
||||
|
||||
def get_hassjob_type(self, function_name: str) -> HassJobType:
|
||||
"""Get the job type function for the given name.
|
||||
|
||||
This is used for entity service calls to avoid
|
||||
figuring out the job type each time.
|
||||
"""
|
||||
if not self._job_types:
|
||||
self._job_types = {}
|
||||
if function_name not in self._job_types:
|
||||
self._job_types[function_name] = get_hassjob_callable_job_type(
|
||||
getattr(self, function_name)
|
||||
)
|
||||
return self._job_types[function_name]
|
||||
|
||||
@cached_property
|
||||
def should_poll(self) -> bool:
|
||||
"""Return True if entity has to be polled for state.
|
||||
|
|
|
@ -963,7 +963,10 @@ async def _handle_entity_call(
|
|||
|
||||
task: asyncio.Future[ServiceResponse] | None
|
||||
if isinstance(func, str):
|
||||
job = HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type]
|
||||
job = HassJob(
|
||||
partial(getattr(entity, func), **data), # type: ignore[arg-type]
|
||||
job_type=entity.get_hassjob_type(func),
|
||||
)
|
||||
task = hass.async_run_hass_job(job, eager_start=True)
|
||||
else:
|
||||
task = hass.async_run_hass_job(func, entity, data, eager_start=True)
|
||||
|
|
|
@ -23,7 +23,13 @@ from homeassistant.const import (
|
|||
STATE_UNAVAILABLE,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, HomeAssistantError
|
||||
from homeassistant.core import (
|
||||
Context,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
HomeAssistantError,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import device_registry as dr, entity, entity_registry as er
|
||||
from homeassistant.helpers.entity_component import async_update_entity
|
||||
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||
|
@ -2559,3 +2565,26 @@ async def test_reset_right_after_remove_entity_registry(
|
|||
assert len(ent.remove_calls) == 1
|
||||
|
||||
assert hass.states.get("test.test") is None
|
||||
|
||||
|
||||
async def test_get_hassjob_type(hass: HomeAssistant) -> None:
|
||||
"""Test get_hassjob_type."""
|
||||
|
||||
class AsyncEntity(entity.Entity):
|
||||
"""Test entity."""
|
||||
|
||||
def update(self):
|
||||
"""Test update Executor."""
|
||||
|
||||
async def async_update(self):
|
||||
"""Test update Coroutinefunction."""
|
||||
|
||||
@callback
|
||||
def update_callback(self):
|
||||
"""Test update Callback."""
|
||||
|
||||
ent_1 = AsyncEntity()
|
||||
|
||||
assert ent_1.get_hassjob_type("update") is HassJobType.Executor
|
||||
assert ent_1.get_hassjob_type("async_update") is HassJobType.Coroutinefunction
|
||||
assert ent_1.get_hassjob_type("update_callback") is HassJobType.Callback
|
||||
|
|
Loading…
Add table
Reference in a new issue