Improve validation of entity service schemas (#124102)

* Improve validation of entity service schemas

* Update tests/helpers/test_entity_platform.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
Erik Montnemery 2024-08-27 19:05:49 +02:00 committed by GitHub
parent 0dc1eb8757
commit 55c42fde88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 87 additions and 49 deletions

View file

@ -1305,9 +1305,28 @@ TARGET_SERVICE_FIELDS = {
_HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS) _HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS)
def is_entity_service_schema(validator: VolSchemaType) -> bool:
"""Check if the passed validator is an entity schema validator.
The validator must be either of:
- A validator returned by cv._make_entity_service_schema
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.All
Nesting is allowed.
"""
if hasattr(validator, "_entity_service_schema"):
return True
if isinstance(validator, (vol.All)):
return any(is_entity_service_schema(val) for val in validator.validators)
if isinstance(validator, (vol.Schema)):
return is_entity_service_schema(validator.schema)
return False
def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType: def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
"""Create an entity service schema.""" """Create an entity service schema."""
return vol.All( validator = vol.All(
vol.Schema( vol.Schema(
{ {
# The frontend stores data here. Don't use in core. # The frontend stores data here. Don't use in core.
@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
), ),
_HAS_ENTITY_SERVICE_FIELD, _HAS_ENTITY_SERVICE_FIELD,
) )
setattr(validator, "_entity_service_schema", True)
return validator
BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA) BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA)

View file

@ -1267,17 +1267,8 @@ def async_register_entity_service(
# Do a sanity check to check this is a valid entity service schema, # Do a sanity check to check this is a valid entity service schema,
# the check could be extended to require All/Any to have sub schema(s) # the check could be extended to require All/Any to have sub schema(s)
# with all entity service fields # with all entity service fields
elif ( elif not cv.is_entity_service_schema(schema):
# Don't check All/Any raise HomeAssistantError("The schema is not an entity service schema")
not isinstance(schema, (vol.All, vol.Any))
# Don't check All/Any wrapped in schema
and not isinstance(schema.schema, (vol.All, vol.Any))
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
):
raise HomeAssistantError(
"The schema does not include all required keys: "
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
)
service_func: str | HassJob[..., Any] service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func) service_func = func if isinstance(func, str) else HassJob(func)

View file

@ -1805,3 +1805,27 @@ async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> Non
"string": [hass.loop_thread_id], "string": [hass.loop_thread_id],
} }
validator_calls = {} validator_calls = {}
async def test_is_entity_service_schema(
hass: HomeAssistant,
) -> None:
"""Test cv.is_entity_service_schema."""
for schema in (
vol.Schema({"some": str}),
vol.All(vol.Schema({"some": str})),
vol.Any(vol.Schema({"some": str})),
vol.Any(cv.make_entity_service_schema({"some": str})),
):
assert cv.is_entity_service_schema(schema) is False
for schema in (
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))),
vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))),
vol.All(cv.make_entity_service_schema({"some": str})),
vol.All(vol.All(cv.make_entity_service_schema({"some": str}))),
vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))),
):
assert cv.is_entity_service_schema(schema) is True

View file

@ -23,7 +23,7 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import discovery from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.entity_component import EntityComponent, async_update_entity from homeassistant.helpers.entity_component import EntityComponent, async_update_entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
@ -559,28 +559,28 @@ async def test_register_entity_service(
async def test_register_entity_service_non_entity_service_schema( async def test_register_entity_service_non_entity_service_schema(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test attempting to register a service with an incomplete schema.""" """Test attempting to register a service with a non entity service schema."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
with pytest.raises( for schema in (
HomeAssistantError, vol.Schema({"some": str}),
match=( vol.All(vol.Schema({"some": str})),
"The schema does not include all required keys: entity_id, device_id, area_id, " vol.Any(vol.Schema({"some": str})),
"floor_id, label_id"
),
): ):
component.async_register_entity_service( with pytest.raises(
"hello", vol.Schema({"some": str}), Mock() HomeAssistantError,
) match=("The schema is not an entity service schema"),
):
component.async_register_entity_service("hello", schema, Mock())
# The check currently does not recurse into vol.All or vol.Any allowing these for idx, schema in enumerate(
# non-compliant schemas to pass (
component.async_register_entity_service( cv.make_entity_service_schema({"some": str}),
"hello", vol.All(vol.Schema({"some": str})), Mock() vol.Schema(cv.make_entity_service_schema({"some": str})),
) vol.All(cv.make_entity_service_schema({"some": str})),
component.async_register_entity_service( )
"hello", vol.Any(vol.Schema({"some": str})), Mock() ):
) component.async_register_entity_service(f"test_service_{idx}", schema, Mock())
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None: async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:

View file

@ -23,6 +23,7 @@ from homeassistant.core import (
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
config_validation as cv,
device_registry as dr, device_registry as dr,
entity_platform, entity_platform,
entity_registry as er, entity_registry as er,
@ -1812,31 +1813,32 @@ async def test_register_entity_service_none_schema(
async def test_register_entity_service_non_entity_service_schema( async def test_register_entity_service_non_entity_service_schema(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test attempting to register a service with an incomplete schema.""" """Test attempting to register a service with a non entity service schema."""
entity_platform = MockEntityPlatform( entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None hass, domain="mock_integration", platform_name="mock_platform", platform=None
) )
with pytest.raises( for schema in (
HomeAssistantError, vol.Schema({"some": str}),
match=( vol.All(vol.Schema({"some": str})),
"The schema does not include all required keys: entity_id, device_id, area_id, " vol.Any(vol.Schema({"some": str})),
"floor_id, label_id" ):
), with pytest.raises(
HomeAssistantError,
match="The schema is not an entity service schema",
):
entity_platform.async_register_entity_service("hello", schema, Mock())
for idx, schema in enumerate(
(
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.All(cv.make_entity_service_schema({"some": str})),
)
): ):
entity_platform.async_register_entity_service( entity_platform.async_register_entity_service(
"hello", f"test_service_{idx}", schema, Mock()
vol.Schema({"some": str}),
Mock(),
) )
# The check currently does not recurse into vol.All or vol.Any allowing these
# non-compliant schemas to pass
entity_platform.async_register_entity_service(
"hello", vol.All(vol.Schema({"some": str})), Mock()
)
entity_platform.async_register_entity_service(
"hello", vol.Any(vol.Schema({"some": str})), Mock()
)
@pytest.mark.parametrize("update_before_add", [True, False]) @pytest.mark.parametrize("update_before_add", [True, False])