Add script + extra config validators

* Add config validation and extra validators

* Address PR comments
This commit is contained in:
Paulus Schoutsen 2016-04-03 10:19:09 -07:00
parent 7ed5055fa2
commit e140e9b8ab
9 changed files with 404 additions and 88 deletions

View file

@ -95,18 +95,18 @@ def setup(hass, config):
'demo': {
'alias': 'Toggle {}'.format(lights[0].split('.')[1]),
'sequence': [{
'execute_service': 'light.turn_off',
'service_data': {ATTR_ENTITY_ID: lights[0]}
'service': 'light.turn_off',
'data': {ATTR_ENTITY_ID: lights[0]}
}, {
'delay': {'seconds': 5}
}, {
'execute_service': 'light.turn_on',
'service_data': {ATTR_ENTITY_ID: lights[0]}
'service': 'light.turn_on',
'data': {ATTR_ENTITY_ID: lights[0]}
}, {
'delay': {'seconds': 5}
}, {
'execute_service': 'light.turn_off',
'service_data': {ATTR_ENTITY_ID: lights[0]}
'service': 'light.turn_off',
'data': {ATTR_ENTITY_ID: lights[0]}
}]
}}})

View file

@ -123,7 +123,7 @@ MEDIA_PLAYER_SCHEMA = vol.Schema({
})
MEDIA_PLAYER_MUTE_VOLUME_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({
vol.Required(ATTR_MEDIA_VOLUME_MUTED): vol.Coerce(bool),
vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean,
})
MEDIA_PLAYER_SET_VOLUME_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({

View file

@ -72,8 +72,7 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({
vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, 'payload'): cv.string,
vol.Required(ATTR_QOS, default=DEFAULT_QOS):
vol.All(vol.Coerce(int), vol.In([0, 1, 2])),
# pylint: disable=no-value-for-parameter
vol.Required(ATTR_RETAIN, default=DEFAULT_RETAIN): vol.Boolean(),
vol.Required(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
}, required=True)

View file

@ -12,6 +12,8 @@ import threading
from datetime import timedelta
from itertools import islice
import voluptuous as vol
import homeassistant.util.dt as date_util
from homeassistant.const import (
ATTR_ENTITY_ID, EVENT_TIME_CHANGED, SERVICE_TURN_OFF, SERVICE_TURN_ON,
@ -21,7 +23,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import track_point_in_utc_time
from homeassistant.helpers.service import (call_from_config,
validate_service_call)
from homeassistant.util import slugify
import homeassistant.helpers.config_validation as cv
DOMAIN = "script"
ENTITY_ID_FORMAT = DOMAIN + '.{}'
@ -42,6 +44,62 @@ ATTR_CAN_CANCEL = 'can_cancel'
_LOGGER = logging.getLogger(__name__)
_ALIAS_VALIDATOR = vol.Schema(cv.string)
def _alias_stripper(validator):
"""Strip alias from object for validation."""
def validate(value):
"""Validate without alias value."""
value = value.copy()
alias = value.pop(CONF_ALIAS, None)
if alias is not None:
alias = _ALIAS_VALIDATOR(alias)
value = validator(value)
if alias is not None:
value[CONF_ALIAS] = alias
return value
return validate
_DELAY_SCHEMA = {
vol.Required(CONF_DELAY): vol.All({
CONF_ALIAS: cv.string,
'days': vol.All(vol.Coerce(int), vol.Range(min=0)),
'seconds': vol.All(vol.Coerce(int), vol.Range(min=0)),
'microseconds': vol.All(vol.Coerce(int), vol.Range(min=0)),
'milliseconds': vol.All(vol.Coerce(int), vol.Range(min=0)),
'minutes': vol.All(vol.Coerce(int), vol.Range(min=0)),
'hours': vol.All(vol.Coerce(int), vol.Range(min=0)),
'weeks': vol.All(vol.Coerce(int), vol.Range(min=0)),
}, cv.has_at_least_one_key([
'days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours',
'weeks']))
}
_EVENT_SCHEMA = cv.EVENT_SCHEMA.extend({
CONF_ALIAS: cv.string,
})
_SCRIPT_ENTRY_SCHEMA = vol.Schema({
CONF_ALIAS: cv.string,
vol.Required(CONF_SEQUENCE): vol.All(vol.Length(min=1), [vol.Any(
_EVENT_SCHEMA,
_DELAY_SCHEMA,
# Can't extend SERVICE_SCHEMA because it is an vol.All
_alias_stripper(cv.SERVICE_SCHEMA),
)]),
})
CONFIG_SCHEMA = vol.Schema({
vol.Required(DOMAIN): cv.DictValidator(_SCRIPT_ENTRY_SCHEMA, cv.slug)
}, extra=vol.ALLOW_EXTRA)
def is_on(hass, entity_id):
"""Return if the switch is on based on the statemachine."""
@ -73,22 +131,12 @@ def setup(hass, config):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.entities.get(entity_id)
if not script:
return
if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id)
return
script.turn_on()
for object_id, cfg in config[DOMAIN].items():
if object_id != slugify(object_id):
_LOGGER.warning("Found invalid key for script: %s. Use %s instead",
object_id, slugify(object_id))
continue
if not isinstance(cfg.get(CONF_SEQUENCE), list):
_LOGGER.warning("Key 'sequence' for script %s should be a list",
object_id)
continue
alias = cfg.get(CONF_ALIAS, object_id)
script = Script(object_id, alias, cfg[CONF_SEQUENCE])
component.add_entities((script,))

View file

@ -1,22 +1,34 @@
"""Helpers for config validation using voluptuous."""
import jinja2
import voluptuous as vol
from homeassistant.const import (
CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELCIUS, TEMP_FAHRENHEIT)
from homeassistant.helpers.entity import valid_entity_id
import homeassistant.util.dt as dt_util
from homeassistant.util import slugify
# pylint: disable=invalid-name
PLATFORM_SCHEMA = vol.Schema({
vol.Required(CONF_PLATFORM): str,
CONF_SCAN_INTERVAL: vol.All(vol.Coerce(int), vol.Range(min=1)),
}, extra=vol.ALLOW_EXTRA)
# Home Assistant types
byte = vol.All(vol.Coerce(int), vol.Range(min=0, max=255))
small_float = vol.All(vol.Coerce(float), vol.Range(min=0, max=1))
latitude = vol.All(vol.Coerce(float), vol.Range(min=-90, max=90))
longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180))
latitude = vol.All(vol.Coerce(float), vol.Range(min=-90, max=90),
msg='invalid latitude')
longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180),
msg='invalid longitude')
def boolean(value):
"""Validate and coerce a boolean value."""
if isinstance(value, str):
if value in ('1', 'true', 'yes', 'on', 'enable'):
return True
if value in ('0', 'false', 'no', 'off', 'disable'):
return False
raise vol.Invalid('invalid boolean value {}'.format(value))
return bool(value)
def entity_id(value):
@ -48,22 +60,54 @@ def icon(value):
raise vol.Invalid('Icons should start with prefix "mdi:"')
def service(value):
"""Validate service."""
# Services use same format as entities so we can use same helper.
if valid_entity_id(value):
return value
raise vol.Invalid('Service {} does not match format <domain>.<name>'
.format(value))
def slug(value):
"""Validate value is a valid slug."""
if value is None:
raise vol.Invalid('Slug should not be None')
value = str(value)
slg = slugify(value)
if value == slg:
return value
raise vol.Invalid('invalid slug {} (try {})'.format(value, slg))
def string(value):
"""Coerce value to string, except for None."""
if value is not None:
return str(value)
raise vol.Invalid('Value should not be None')
raise vol.Invalid('string value is None')
def temperature_unit(value):
"""Validate and transform temperature unit."""
if isinstance(value, str):
value = value.upper()
if value == 'C':
return TEMP_CELCIUS
elif value == 'F':
return TEMP_FAHRENHEIT
raise vol.Invalid('Invalid temperature unit. Expected: C or F')
value = str(value).upper()
if value == 'C':
return TEMP_CELCIUS
elif value == 'F':
return TEMP_FAHRENHEIT
raise vol.Invalid('invalid temperature unit (expected C or F)')
def template(value):
"""Validate a jinja2 template."""
if value is None:
raise vol.Invalid('template value is None')
value = str(value)
try:
jinja2.Environment().parse(value)
return value
except jinja2.exceptions.TemplateSyntaxError as ex:
raise vol.Invalid('invalid template ({})'.format(ex))
def time_zone(value):
@ -73,3 +117,93 @@ def time_zone(value):
raise vol.Invalid(
'Invalid time zone passed in. Valid options can be found here: '
'http://en.wikipedia.org/wiki/List_of_tz_database_time_zones')
# Validator helpers
# pylint: disable=too-few-public-methods
class DictValidator(object):
"""Validate keys and values in a dictionary."""
def __init__(self, value_validator=None, key_validator=None):
"""Initialize the dict validator."""
if value_validator is not None:
value_validator = vol.Schema(value_validator)
self.value_validator = value_validator
if key_validator is not None:
key_validator = vol.Schema(key_validator)
self.key_validator = key_validator
def __call__(self, obj):
"""Validate the dict."""
if not isinstance(obj, dict):
raise vol.Invalid('Expected dictionary.')
errors = []
# So we keep it an OrderedDict if it is one
result = obj.__class__()
for key, value in obj.items():
if self.key_validator is not None:
try:
key = self.key_validator(key)
except vol.Invalid as ex:
errors.append('key {} is invalid ({})'.format(key, ex))
if self.value_validator is not None:
try:
value = self.value_validator(value)
except vol.Invalid as ex:
errors.append(
'key {} contains invalid value ({})'.format(key, ex))
if not errors:
result[key] = value
if errors:
raise vol.Invalid(
'invalid dictionary: {}'.format(', '.join(errors)))
return result
# Adapted from:
# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666
def has_at_least_one_key(keys):
"""Validator that at least one key exists."""
def validate(obj):
"""Test keys exist in dict."""
if not isinstance(obj, dict):
raise vol.Invalid('expected dictionary')
for k in obj.keys():
if k in keys:
return obj
raise vol.Invalid('must contain one of {}.'.format(', '.join(keys)))
return validate
# Schemas
PLATFORM_SCHEMA = vol.Schema({
vol.Required(CONF_PLATFORM): string,
CONF_SCAN_INTERVAL: vol.All(vol.Coerce(int), vol.Range(min=1)),
}, extra=vol.ALLOW_EXTRA)
EVENT_SCHEMA = vol.Schema({
vol.Required('event'): string,
'event_data': dict
})
SERVICE_SCHEMA = vol.All(vol.Schema({
vol.Exclusive('service', 'service name'): service,
vol.Exclusive('service_template', 'service name'): string,
vol.Exclusive('data', 'service data'): dict,
vol.Exclusive('data_template', 'service data'): DictValidator(template),
}), has_at_least_one_key(['service', 'service_template']))

View file

@ -5,7 +5,7 @@ pytz>=2016.3
pip>=7.0.0
vincenty==0.1.4
jinja2>=2.8
voluptuous==0.8.9
voluptuous==0.8.10
# homeassistant.components.isy994
PyISY==1.0.5

View file

@ -17,7 +17,7 @@ REQUIRES = [
'pip>=7.0.0',
'vincenty==0.1.4',
'jinja2>=2.8',
'voluptuous==0.8.9',
'voluptuous==0.8.10',
]
setup(

View file

@ -3,6 +3,7 @@
from datetime import timedelta
import unittest
from homeassistant.bootstrap import _setup_component
from homeassistant.components import script
import homeassistant.util.dt as dt_util
@ -18,46 +19,43 @@ class TestScript(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.hass.config.components.append('group')
def tearDown(self): # pylint: disable=invalid-name
"""Stop down everything that was started."""
self.hass.stop()
def test_setup_with_missing_sequence(self):
"""Test setup with missing sequence."""
self.assertTrue(script.setup(self.hass, {
'script': {
'test': {}
}
}))
self.assertEqual(0, len(self.hass.states.entity_ids('script')))
def test_setup_with_invalid_object_id(self):
"""Test setup with invalid objects."""
self.assertTrue(script.setup(self.hass, {
'script': {
def test_setup_with_invalid_configs(self):
"""Test setup with invalid configs."""
for value in (
{'test': {}},
{
'test hello world': {
'sequence': []
}
}
}))
self.assertEqual(0, len(self.hass.states.entity_ids('script')))
def test_setup_with_dict_as_sequence(self):
"""Test setup with dict as sequence."""
self.assertTrue(script.setup(self.hass, {
'script': {
},
{
'test': {
'sequence': {
'event': 'test_event'
}
}
}
}))
},
{
'test': {
'sequence': {
'event': 'test_event',
'service': 'homeassistant.turn_on',
}
}
},
self.assertEqual(0, len(self.hass.states.entity_ids('script')))
):
assert not _setup_component(self.hass, 'script', {
'script': value
}), 'Script loaded with wrong config {}'.format(value)
self.assertEqual(0, len(self.hass.states.entity_ids('script')))
def test_firing_event(self):
"""Test the firing of events."""
@ -70,7 +68,7 @@ class TestScript(unittest.TestCase):
self.hass.bus.listen(event, record_event)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'alias': 'Test Script',
@ -82,7 +80,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -102,7 +100,7 @@ class TestScript(unittest.TestCase):
self.hass.services.register('test', 'script', record_call)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -113,7 +111,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -131,7 +129,7 @@ class TestScript(unittest.TestCase):
self.hass.services.register('test', 'script', record_call)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -153,7 +151,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -172,7 +170,7 @@ class TestScript(unittest.TestCase):
self.hass.bus.listen(event, record_event)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -186,7 +184,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -219,7 +217,7 @@ class TestScript(unittest.TestCase):
self.hass.bus.listen(event, record_event)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -231,7 +229,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -263,7 +261,7 @@ class TestScript(unittest.TestCase):
self.hass.bus.listen(event, record_event)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -275,7 +273,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.turn_on(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()
@ -298,7 +296,7 @@ class TestScript(unittest.TestCase):
self.hass.bus.listen(event, record_event)
self.assertTrue(script.setup(self.hass, {
assert _setup_component(self.hass, 'script', {
'script': {
'test': {
'sequence': [{
@ -310,7 +308,7 @@ class TestScript(unittest.TestCase):
}]
}
}
}))
})
script.toggle(self.hass, ENTITY_ID)
self.hass.pool.block_till_done()

View file

@ -28,21 +28,9 @@ def test_longitude():
schema(value)
def test_icon():
"""Test icon validation."""
schema = vol.Schema(cv.icon)
for value in (False, 'work', 'icon:work'):
with pytest.raises(vol.MultipleInvalid):
schema(value)
schema('mdi:work')
def test_platform_config():
"""Test platform config validation."""
for value in (
{'platform': 1},
{},
{'hello': 'world'},
):
@ -92,6 +80,103 @@ def test_entity_ids():
]
def test_event_schema():
"""Test event_schema validation."""
for value in (
{}, None,
{
'event_data': {},
},
{
'event': 'state_changed',
'event_data': 1,
},
):
with pytest.raises(vol.MultipleInvalid):
cv.EVENT_SCHEMA(value)
for value in (
{'event': 'state_changed'},
{'event': 'state_changed', 'event_data': {'hello': 'world'}},
):
cv.EVENT_SCHEMA(value)
def test_icon():
"""Test icon validation."""
schema = vol.Schema(cv.icon)
for value in (False, 'work', 'icon:work'):
with pytest.raises(vol.MultipleInvalid):
schema(value)
schema('mdi:work')
def test_service():
"""Test service validation."""
schema = vol.Schema(cv.service)
with pytest.raises(vol.MultipleInvalid):
schema('invalid_turn_on')
schema('homeassistant.turn_on')
def test_service_schema():
"""Test service_schema validation."""
for value in (
{}, None,
{
'service': 'homeassistant.turn_on',
'service_template': 'homeassistant.turn_on'
},
{
'data': {'entity_id': 'light.kitchen'},
},
{
'service': 'homeassistant.turn_on',
'data': None
},
{
'service': 'homeassistant.turn_on',
'data_template': {
'brightness': '{{ no_end'
}
},
):
with pytest.raises(vol.MultipleInvalid):
cv.SERVICE_SCHEMA(value)
for value in (
{'service': 'homeassistant.turn_on'},
):
cv.SERVICE_SCHEMA(value)
def test_slug():
"""Test slug validation."""
schema = vol.Schema(cv.slug)
for value in (None, 'hello world'):
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in (12345, 'hello'):
schema(value)
def test_string():
"""Test string validation."""
schema = vol.Schema(cv.string)
with pytest.raises(vol.MultipleInvalid):
schema(None)
for value in (True, 1, 'hello'):
schema(value)
def test_temperature_unit():
"""Test temperature unit validation."""
schema = vol.Schema(cv.temperature_unit)
@ -103,6 +188,24 @@ def test_temperature_unit():
schema('F')
def test_template():
"""Test template validator."""
schema = vol.Schema(cv.template)
for value in (
None, '{{ partial_print }', '{% if True %}Hello'
):
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in (
1, 'Hello',
'{{ beer }}',
'{% if 1 == 1 %}Hello{% else %}World{% endif %}',
):
schema(value)
def test_time_zone():
"""Test time zone validation."""
schema = vol.Schema(cv.time_zone)
@ -112,3 +215,37 @@ def test_time_zone():
schema('America/Los_Angeles')
schema('UTC')
def test_dict_validator():
"""Test DictValidator."""
schema = vol.Schema(cv.DictValidator(cv.entity_ids, cv.slug))
for value in (
None,
{'invalid slug': 'sensor.temp'},
{'hello world': 'invalid_entity'}
):
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in (
{},
{'hello_world': 'sensor.temp'},
):
schema(value)
assert schema({'hello_world': 'sensor.temp'}) == \
{'hello_world': ['sensor.temp']}
def test_has_at_least_one_key():
"""Test has_at_least_one_key validator."""
schema = vol.Schema(cv.has_at_least_one_key(['beer', 'soda']))
for value in (None, [], {}, {'wine': None}):
with pytest.raises(vol.MultipleInvalid):
schema(value)
for value in ({'beer': None}, {'soda': None}):
schema(value)