Add async friendly helper for validating config schemas (#123800)
* Add async friendly helper for validating config schemas * Improve docstrings * Add tests
This commit is contained in:
parent
a7bca9bcea
commit
533442f33e
4 changed files with 161 additions and 7 deletions
|
@ -1535,7 +1535,9 @@ async def async_process_component_config(
|
|||
# No custom config validator, proceed with schema validation
|
||||
if hasattr(component, "CONFIG_SCHEMA"):
|
||||
try:
|
||||
return IntegrationConfigInfo(component.CONFIG_SCHEMA(config), [])
|
||||
return IntegrationConfigInfo(
|
||||
await cv.async_validate(hass, component.CONFIG_SCHEMA, config), []
|
||||
)
|
||||
except vol.Invalid as exc:
|
||||
exc_info = ConfigExceptionInfo(
|
||||
exc,
|
||||
|
@ -1570,7 +1572,9 @@ async def async_process_component_config(
|
|||
# Validate component specific platform schema
|
||||
platform_path = f"{p_name}.{domain}"
|
||||
try:
|
||||
p_validated = component_platform_schema(p_config)
|
||||
p_validated = await cv.async_validate(
|
||||
hass, component_platform_schema, p_config
|
||||
)
|
||||
except vol.Invalid as exc:
|
||||
exc_info = ConfigExceptionInfo(
|
||||
exc,
|
||||
|
|
|
@ -234,7 +234,7 @@ async def async_check_ha_config_file( # noqa: C901
|
|||
config_schema = getattr(component, "CONFIG_SCHEMA", None)
|
||||
if config_schema is not None:
|
||||
try:
|
||||
validated_config = config_schema(config)
|
||||
validated_config = await cv.async_validate(hass, config_schema, config)
|
||||
# Don't fail if the validator removed the domain from the config
|
||||
if domain in validated_config:
|
||||
result[domain] = validated_config[domain]
|
||||
|
@ -255,7 +255,9 @@ async def async_check_ha_config_file( # noqa: C901
|
|||
for p_name, p_config in config_per_platform(config, domain):
|
||||
# Validate component specific platform schema
|
||||
try:
|
||||
p_validated = component_platform_schema(p_config)
|
||||
p_validated = await cv.async_validate(
|
||||
hass, component_platform_schema, p_config
|
||||
)
|
||||
except vol.Invalid as ex:
|
||||
_comp_error(ex, domain, p_config, p_config)
|
||||
continue
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from collections.abc import Callable, Hashable
|
||||
import contextlib
|
||||
from contextvars import ContextVar
|
||||
from datetime import (
|
||||
date as date_sys,
|
||||
datetime as datetime_sys,
|
||||
|
@ -13,6 +14,7 @@ from datetime import (
|
|||
timedelta,
|
||||
)
|
||||
from enum import Enum, StrEnum
|
||||
import functools
|
||||
import logging
|
||||
from numbers import Number
|
||||
import os
|
||||
|
@ -20,6 +22,7 @@ import re
|
|||
from socket import ( # type: ignore[attr-defined] # private, not in typeshed
|
||||
_GLOBAL_DEFAULT_TIMEOUT,
|
||||
)
|
||||
import threading
|
||||
from typing import Any, cast, overload
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
@ -94,6 +97,7 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import (
|
||||
DOMAIN as HOMEASSISTANT_DOMAIN,
|
||||
HomeAssistant,
|
||||
async_get_hass,
|
||||
async_get_hass_or_none,
|
||||
split_entity_id,
|
||||
|
@ -114,6 +118,51 @@ from .typing import VolDictType, VolSchemaType
|
|||
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'"
|
||||
|
||||
|
||||
class MustValidateInExecutor(HomeAssistantError):
|
||||
"""Raised when validation must happen in an executor thread."""
|
||||
|
||||
|
||||
class _Hass(threading.local):
|
||||
"""Container which makes a HomeAssistant instance available to validators."""
|
||||
|
||||
hass: HomeAssistant | None = None
|
||||
|
||||
|
||||
_hass = _Hass()
|
||||
"""Set when doing async friendly schema validation."""
|
||||
|
||||
|
||||
def _async_get_hass_or_none() -> HomeAssistant | None:
|
||||
"""Return the HomeAssistant instance or None.
|
||||
|
||||
First tries core.async_get_hass_or_none, then _hass which is
|
||||
set when doing async friendly schema validation.
|
||||
"""
|
||||
return async_get_hass_or_none() or _hass.hass
|
||||
|
||||
|
||||
_validating_async: ContextVar[bool] = ContextVar("_validating_async", default=False)
|
||||
"""Set to True when doing async friendly schema validation."""
|
||||
|
||||
|
||||
def not_async_friendly[**_P, _R](validator: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
"""Mark a validator as not async friendly.
|
||||
|
||||
This makes validation happen in an executor thread if validation is done by
|
||||
async_validate, otherwise does nothing.
|
||||
"""
|
||||
|
||||
@functools.wraps(validator)
|
||||
def _not_async_friendly(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if _validating_async.get() and async_get_hass_or_none():
|
||||
# Raise if doing async friendly validation and validation
|
||||
# is happening in the event loop
|
||||
raise MustValidateInExecutor
|
||||
return validator(*args, **kwargs)
|
||||
|
||||
return _not_async_friendly
|
||||
|
||||
|
||||
class UrlProtocolSchema(StrEnum):
|
||||
"""Valid URL protocol schema values."""
|
||||
|
||||
|
@ -217,6 +266,7 @@ def whitespace(value: Any) -> str:
|
|||
raise vol.Invalid(f"contains non-whitespace: {value}")
|
||||
|
||||
|
||||
@not_async_friendly
|
||||
def isdevice(value: Any) -> str:
|
||||
"""Validate that value is a real device."""
|
||||
try:
|
||||
|
@ -258,6 +308,7 @@ def is_regex(value: Any) -> re.Pattern[Any]:
|
|||
return r
|
||||
|
||||
|
||||
@not_async_friendly
|
||||
def isfile(value: Any) -> str:
|
||||
"""Validate that the value is an existing file."""
|
||||
if value is None:
|
||||
|
@ -271,6 +322,7 @@ def isfile(value: Any) -> str:
|
|||
return file_in
|
||||
|
||||
|
||||
@not_async_friendly
|
||||
def isdir(value: Any) -> str:
|
||||
"""Validate that the value is an existing dir."""
|
||||
if value is None:
|
||||
|
@ -664,7 +716,7 @@ def template(value: Any | None) -> template_helper.Template:
|
|||
if isinstance(value, (list, dict, template_helper.Template)):
|
||||
raise vol.Invalid("template value should be a string")
|
||||
|
||||
template_value = template_helper.Template(str(value), async_get_hass_or_none())
|
||||
template_value = template_helper.Template(str(value), _async_get_hass_or_none())
|
||||
|
||||
try:
|
||||
template_value.ensure_valid()
|
||||
|
@ -682,7 +734,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
|
|||
if not template_helper.is_template_string(str(value)):
|
||||
raise vol.Invalid("template value does not contain a dynamic template")
|
||||
|
||||
template_value = template_helper.Template(str(value), async_get_hass_or_none())
|
||||
template_value = template_helper.Template(str(value), _async_get_hass_or_none())
|
||||
|
||||
try:
|
||||
template_value.ensure_valid()
|
||||
|
@ -1918,3 +1970,32 @@ historic_currency = vol.In(
|
|||
country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country")
|
||||
|
||||
language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language")
|
||||
|
||||
|
||||
async def async_validate(
|
||||
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
|
||||
) -> Any:
|
||||
"""Async friendly schema validation.
|
||||
|
||||
If a validator decorated with @not_async_friendly is called, validation will be
|
||||
deferred to an executor. If not, validation will happen in the event loop.
|
||||
"""
|
||||
_validating_async.set(True)
|
||||
try:
|
||||
return validator(value)
|
||||
except MustValidateInExecutor:
|
||||
return await hass.async_add_executor_job(
|
||||
_validate_in_executor, hass, validator, value
|
||||
)
|
||||
finally:
|
||||
_validating_async.set(False)
|
||||
|
||||
|
||||
def _validate_in_executor(
|
||||
hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
|
||||
) -> Any:
|
||||
_hass.hass = hass
|
||||
try:
|
||||
return validator(value)
|
||||
finally:
|
||||
_hass.hass = None
|
||||
|
|
|
@ -3,13 +3,16 @@
|
|||
from collections import OrderedDict
|
||||
from datetime import date, datetime, timedelta
|
||||
import enum
|
||||
from functools import partial
|
||||
import logging
|
||||
import os
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT
|
||||
import threading
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
import uuid
|
||||
|
||||
import py
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -1738,3 +1741,67 @@ def test_determine_script_action_ambiguous() -> None:
|
|||
def test_determine_script_action_non_ambiguous() -> None:
|
||||
"""Test determine script action with a non ambiguous action."""
|
||||
assert cv.determine_script_action({"delay": "00:00:05"}) == "delay"
|
||||
|
||||
|
||||
async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> None:
|
||||
"""Test the async_validate helper."""
|
||||
validator_calls: dict[str, list[int]] = {}
|
||||
|
||||
def _mock_validator_schema(real_func, *args):
|
||||
calls = validator_calls.setdefault(real_func.__name__, [])
|
||||
calls.append(threading.get_ident())
|
||||
return real_func(*args)
|
||||
|
||||
CV_PREFIX = "homeassistant.helpers.config_validation"
|
||||
with (
|
||||
patch(f"{CV_PREFIX}.isdir", wraps=partial(_mock_validator_schema, cv.isdir)),
|
||||
patch(f"{CV_PREFIX}.string", wraps=partial(_mock_validator_schema, cv.string)),
|
||||
):
|
||||
# Assert validation in event loop when not decorated with not_async_friendly
|
||||
await cv.async_validate(hass, cv.string, "abcd")
|
||||
assert validator_calls == {"string": [hass.loop_thread_id]}
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in executor when decorated with not_async_friendly
|
||||
await cv.async_validate(hass, cv.isdir, tmpdir)
|
||||
assert validator_calls == {"isdir": [hass.loop_thread_id, ANY]}
|
||||
assert validator_calls["isdir"][1] != hass.loop_thread_id
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in executor when decorated with not_async_friendly
|
||||
await cv.async_validate(hass, vol.All(cv.isdir, cv.string), tmpdir)
|
||||
assert validator_calls == {"isdir": [hass.loop_thread_id, ANY], "string": [ANY]}
|
||||
assert validator_calls["isdir"][1] != hass.loop_thread_id
|
||||
assert validator_calls["string"][0] != hass.loop_thread_id
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in executor when decorated with not_async_friendly
|
||||
await cv.async_validate(hass, vol.All(cv.string, cv.isdir), tmpdir)
|
||||
assert validator_calls == {
|
||||
"isdir": [hass.loop_thread_id, ANY],
|
||||
"string": [hass.loop_thread_id, ANY],
|
||||
}
|
||||
assert validator_calls["isdir"][1] != hass.loop_thread_id
|
||||
assert validator_calls["string"][1] != hass.loop_thread_id
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in event loop when not using cv.async_validate
|
||||
cv.isdir(tmpdir)
|
||||
assert validator_calls == {"isdir": [hass.loop_thread_id]}
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in event loop when not using cv.async_validate
|
||||
vol.All(cv.isdir, cv.string)(tmpdir)
|
||||
assert validator_calls == {
|
||||
"isdir": [hass.loop_thread_id],
|
||||
"string": [hass.loop_thread_id],
|
||||
}
|
||||
validator_calls = {}
|
||||
|
||||
# Assert validation in event loop when not using cv.async_validate
|
||||
vol.All(cv.string, cv.isdir)(tmpdir)
|
||||
assert validator_calls == {
|
||||
"isdir": [hass.loop_thread_id],
|
||||
"string": [hass.loop_thread_id],
|
||||
}
|
||||
validator_calls = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue