Do sanity check in EntityComponent.async_register_entity_service schema (#124029)
* Do sanity check in EntityComponent.async_register_entity_service schema * Improve test
This commit is contained in:
parent
183c191d63
commit
799e95c1bd
2 changed files with 42 additions and 0 deletions
|
@ -11,6 +11,7 @@ from types import ModuleType
|
||||||
from typing import Any, Generic
|
from typing import Any, Generic
|
||||||
|
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config as conf_util
|
from homeassistant import config as conf_util
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
@ -266,6 +267,20 @@ class EntityComponent(Generic[_EntityT]):
|
||||||
"""Register an entity service."""
|
"""Register an entity service."""
|
||||||
if schema is None or isinstance(schema, dict):
|
if schema is None or isinstance(schema, dict):
|
||||||
schema = cv.make_entity_service_schema(schema)
|
schema = cv.make_entity_service_schema(schema)
|
||||||
|
# Do a sanity check to check this is a valid entity service schema,
|
||||||
|
# the check could be extended to require All/Any to have sub schema(s)
|
||||||
|
# with all entity service fields
|
||||||
|
elif (
|
||||||
|
# Don't check All/Any
|
||||||
|
not isinstance(schema, (vol.All, vol.Any))
|
||||||
|
# Don't check All/Any wrapped in schema
|
||||||
|
and not isinstance(schema.schema, (vol.All, vol.Any))
|
||||||
|
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
||||||
|
):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"The schema does not include all required keys: "
|
||||||
|
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
||||||
|
)
|
||||||
|
|
||||||
service_func: str | HassJob[..., Any]
|
service_func: str | HassJob[..., Any]
|
||||||
service_func = func if isinstance(func, str) else HassJob(func)
|
service_func = func if isinstance(func, str) else HassJob(func)
|
||||||
|
|
|
@ -556,6 +556,33 @@ async def test_register_entity_service(
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_entity_service_non_entity_service_schema(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test attempting to register a service with an incomplete schema."""
|
||||||
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match=(
|
||||||
|
"The schema does not include all required keys: entity_id, device_id, area_id, "
|
||||||
|
"floor_id, label_id"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"hello", vol.Schema({"some": str}), Mock()
|
||||||
|
)
|
||||||
|
|
||||||
|
# The check currently does not recurse into vol.All or vol.Any allowing these
|
||||||
|
# non-compliatn schemas to pass
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"hello", vol.All(vol.Schema({"some": str})), Mock()
|
||||||
|
)
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"hello", vol.Any(vol.Schema({"some": str})), Mock()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
||||||
"""Test an entity service that does support response data."""
|
"""Test an entity service that does support response data."""
|
||||||
entity = MockEntity(entity_id=f"{DOMAIN}.entity")
|
entity = MockEntity(entity_id=f"{DOMAIN}.entity")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue