Improve validation of entity service schemas (#124102)
* Improve validation of entity service schemas * Update tests/helpers/test_entity_platform.py Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> --------- Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
parent
0dc1eb8757
commit
55c42fde88
5 changed files with 87 additions and 49 deletions
|
@ -1305,9 +1305,28 @@ TARGET_SERVICE_FIELDS = {
|
||||||
_HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS)
|
_HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS)
|
||||||
|
|
||||||
|
|
||||||
|
def is_entity_service_schema(validator: VolSchemaType) -> bool:
|
||||||
|
"""Check if the passed validator is an entity schema validator.
|
||||||
|
|
||||||
|
The validator must be either of:
|
||||||
|
- A validator returned by cv._make_entity_service_schema
|
||||||
|
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema
|
||||||
|
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.All
|
||||||
|
Nesting is allowed.
|
||||||
|
"""
|
||||||
|
if hasattr(validator, "_entity_service_schema"):
|
||||||
|
return True
|
||||||
|
if isinstance(validator, (vol.All)):
|
||||||
|
return any(is_entity_service_schema(val) for val in validator.validators)
|
||||||
|
if isinstance(validator, (vol.Schema)):
|
||||||
|
return is_entity_service_schema(validator.schema)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
|
def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
|
||||||
"""Create an entity service schema."""
|
"""Create an entity service schema."""
|
||||||
return vol.All(
|
validator = vol.All(
|
||||||
vol.Schema(
|
vol.Schema(
|
||||||
{
|
{
|
||||||
# The frontend stores data here. Don't use in core.
|
# The frontend stores data here. Don't use in core.
|
||||||
|
@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
|
||||||
),
|
),
|
||||||
_HAS_ENTITY_SERVICE_FIELD,
|
_HAS_ENTITY_SERVICE_FIELD,
|
||||||
)
|
)
|
||||||
|
setattr(validator, "_entity_service_schema", True)
|
||||||
|
return validator
|
||||||
|
|
||||||
|
|
||||||
BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA)
|
BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA)
|
||||||
|
|
|
@ -1267,17 +1267,8 @@ def async_register_entity_service(
|
||||||
# Do a sanity check to check this is a valid entity service schema,
|
# Do a sanity check to check this is a valid entity service schema,
|
||||||
# the check could be extended to require All/Any to have sub schema(s)
|
# the check could be extended to require All/Any to have sub schema(s)
|
||||||
# with all entity service fields
|
# with all entity service fields
|
||||||
elif (
|
elif not cv.is_entity_service_schema(schema):
|
||||||
# Don't check All/Any
|
raise HomeAssistantError("The schema is not an entity service schema")
|
||||||
not isinstance(schema, (vol.All, vol.Any))
|
|
||||||
# Don't check All/Any wrapped in schema
|
|
||||||
and not isinstance(schema.schema, (vol.All, vol.Any))
|
|
||||||
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
|
||||||
):
|
|
||||||
raise HomeAssistantError(
|
|
||||||
"The schema does not include all required keys: "
|
|
||||||
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
service_func: str | HassJob[..., Any]
|
service_func: str | HassJob[..., Any]
|
||||||
service_func = func if isinstance(func, str) else HassJob(func)
|
service_func = func if isinstance(func, str) else HassJob(func)
|
||||||
|
|
|
@ -1805,3 +1805,27 @@ async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> Non
|
||||||
"string": [hass.loop_thread_id],
|
"string": [hass.loop_thread_id],
|
||||||
}
|
}
|
||||||
validator_calls = {}
|
validator_calls = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_entity_service_schema(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test cv.is_entity_service_schema."""
|
||||||
|
for schema in (
|
||||||
|
vol.Schema({"some": str}),
|
||||||
|
vol.All(vol.Schema({"some": str})),
|
||||||
|
vol.Any(vol.Schema({"some": str})),
|
||||||
|
vol.Any(cv.make_entity_service_schema({"some": str})),
|
||||||
|
):
|
||||||
|
assert cv.is_entity_service_schema(schema) is False
|
||||||
|
|
||||||
|
for schema in (
|
||||||
|
cv.make_entity_service_schema({"some": str}),
|
||||||
|
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||||
|
vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))),
|
||||||
|
vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))),
|
||||||
|
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||||
|
vol.All(vol.All(cv.make_entity_service_schema({"some": str}))),
|
||||||
|
vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))),
|
||||||
|
):
|
||||||
|
assert cv.is_entity_service_schema(schema) is True
|
||||||
|
|
|
@ -23,7 +23,7 @@ from homeassistant.core import (
|
||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||||
from homeassistant.helpers import discovery
|
from homeassistant.helpers import config_validation as cv, discovery
|
||||||
from homeassistant.helpers.entity_component import EntityComponent, async_update_entity
|
from homeassistant.helpers.entity_component import EntityComponent, async_update_entity
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
@ -559,28 +559,28 @@ async def test_register_entity_service(
|
||||||
async def test_register_entity_service_non_entity_service_schema(
|
async def test_register_entity_service_non_entity_service_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test attempting to register a service with an incomplete schema."""
|
"""Test attempting to register a service with a non entity service schema."""
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||||
|
|
||||||
with pytest.raises(
|
for schema in (
|
||||||
HomeAssistantError,
|
vol.Schema({"some": str}),
|
||||||
match=(
|
vol.All(vol.Schema({"some": str})),
|
||||||
"The schema does not include all required keys: entity_id, device_id, area_id, "
|
vol.Any(vol.Schema({"some": str})),
|
||||||
"floor_id, label_id"
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
component.async_register_entity_service(
|
with pytest.raises(
|
||||||
"hello", vol.Schema({"some": str}), Mock()
|
HomeAssistantError,
|
||||||
)
|
match=("The schema is not an entity service schema"),
|
||||||
|
):
|
||||||
|
component.async_register_entity_service("hello", schema, Mock())
|
||||||
|
|
||||||
# The check currently does not recurse into vol.All or vol.Any allowing these
|
for idx, schema in enumerate(
|
||||||
# non-compliant schemas to pass
|
(
|
||||||
component.async_register_entity_service(
|
cv.make_entity_service_schema({"some": str}),
|
||||||
"hello", vol.All(vol.Schema({"some": str})), Mock()
|
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||||
)
|
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||||
component.async_register_entity_service(
|
)
|
||||||
"hello", vol.Any(vol.Schema({"some": str})), Mock()
|
):
|
||||||
)
|
component.async_register_entity_service(f"test_service_{idx}", schema, Mock())
|
||||||
|
|
||||||
|
|
||||||
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
|
||||||
|
|
|
@ -23,6 +23,7 @@ from homeassistant.core import (
|
||||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_platform,
|
entity_platform,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
|
@ -1812,31 +1813,32 @@ async def test_register_entity_service_none_schema(
|
||||||
async def test_register_entity_service_non_entity_service_schema(
|
async def test_register_entity_service_non_entity_service_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test attempting to register a service with an incomplete schema."""
|
"""Test attempting to register a service with a non entity service schema."""
|
||||||
entity_platform = MockEntityPlatform(
|
entity_platform = MockEntityPlatform(
|
||||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(
|
for schema in (
|
||||||
HomeAssistantError,
|
vol.Schema({"some": str}),
|
||||||
match=(
|
vol.All(vol.Schema({"some": str})),
|
||||||
"The schema does not include all required keys: entity_id, device_id, area_id, "
|
vol.Any(vol.Schema({"some": str})),
|
||||||
"floor_id, label_id"
|
):
|
||||||
),
|
with pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match="The schema is not an entity service schema",
|
||||||
|
):
|
||||||
|
entity_platform.async_register_entity_service("hello", schema, Mock())
|
||||||
|
|
||||||
|
for idx, schema in enumerate(
|
||||||
|
(
|
||||||
|
cv.make_entity_service_schema({"some": str}),
|
||||||
|
vol.Schema(cv.make_entity_service_schema({"some": str})),
|
||||||
|
vol.All(cv.make_entity_service_schema({"some": str})),
|
||||||
|
)
|
||||||
):
|
):
|
||||||
entity_platform.async_register_entity_service(
|
entity_platform.async_register_entity_service(
|
||||||
"hello",
|
f"test_service_{idx}", schema, Mock()
|
||||||
vol.Schema({"some": str}),
|
|
||||||
Mock(),
|
|
||||||
)
|
)
|
||||||
# The check currently does not recurse into vol.All or vol.Any allowing these
|
|
||||||
# non-compliant schemas to pass
|
|
||||||
entity_platform.async_register_entity_service(
|
|
||||||
"hello", vol.All(vol.Schema({"some": str})), Mock()
|
|
||||||
)
|
|
||||||
entity_platform.async_register_entity_service(
|
|
||||||
"hello", vol.Any(vol.Schema({"some": str})), Mock()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("update_before_add", [True, False])
|
@pytest.mark.parametrize("update_before_add", [True, False])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue