Improve yaml fault tolerance and handle check_config border cases (#3159)

This commit is contained in:
Johann Kellerman 2016-09-08 22:20:38 +02:00 committed by GitHub
parent 267cda447e
commit e8ad76c816
5 changed files with 90 additions and 59 deletions

View file

@ -221,14 +221,18 @@ def check(config_path):
try: try:
bootstrap.from_config_file(config_path, skip_pip=True) bootstrap.from_config_file(config_path, skip_pip=True)
res['secret_cache'] = yaml.__SECRET_CACHE res['secret_cache'] = dict(yaml.__SECRET_CACHE)
return res except Exception as err: # pylint: disable=broad-except
print(color('red', 'Fatal error while loading config:'), str(err))
finally: finally:
# Stop all patches # Stop all patches
for pat in PATCHES.values(): for pat in PATCHES.values():
pat.stop() pat.stop()
# Ensure !secrets point to the original function # Ensure !secrets point to the original function
yaml.yaml.SafeLoader.add_constructor('!secret', yaml._secret_yaml) yaml.yaml.SafeLoader.add_constructor('!secret', yaml._secret_yaml)
bootstrap.clear_secret_cache()
return res
def dump_dict(layer, indent_count=1, listi=False, **kwargs): def dump_dict(layer, indent_count=1, listi=False, **kwargs):

View file

@ -121,6 +121,16 @@ def _ordered_dict(loader: SafeLineLoader,
line = getattr(node, '__line__', 'unknown') line = getattr(node, '__line__', 'unknown')
if line != 'unknown' and (min_line is None or line < min_line): if line != 'unknown' and (min_line is None or line < min_line):
min_line = line min_line = line
try:
hash(key)
except TypeError:
fname = getattr(loader.stream, 'name', '')
raise yaml.MarkedYAMLError(
context="invalid key: \"{}\"".format(key),
context_mark=yaml.Mark(fname, 0, min_line, -1, None, None)
)
if key in seen: if key in seen:
fname = getattr(loader.stream, 'name', '') fname = getattr(loader.stream, 'name', '')
first_mark = yaml.Mark(fname, 0, seen[key], -1, None, None) first_mark = yaml.Mark(fname, 0, seen[key], -1, None, None)

View file

@ -247,20 +247,23 @@ def patch_yaml_files(files_dict, endswith=True):
"""Patch load_yaml with a dictionary of yaml files.""" """Patch load_yaml with a dictionary of yaml files."""
# match using endswith, start search with longest string # match using endswith, start search with longest string
matchlist = sorted(list(files_dict.keys()), key=len) if endswith else [] matchlist = sorted(list(files_dict.keys()), key=len) if endswith else []
# matchlist.sort(key=len)
def mock_open_f(fname, **_): def mock_open_f(fname, **_):
"""Mock open() in the yaml module, used by load_yaml.""" """Mock open() in the yaml module, used by load_yaml."""
# Return the mocked file on full match # Return the mocked file on full match
if fname in files_dict: if fname in files_dict:
_LOGGER.debug('patch_yaml_files match %s', fname) _LOGGER.debug('patch_yaml_files match %s', fname)
return StringIO(files_dict[fname]) res = StringIO(files_dict[fname])
setattr(res, 'name', fname)
return res
# Match using endswith # Match using endswith
for ends in matchlist: for ends in matchlist:
if fname.endswith(ends): if fname.endswith(ends):
_LOGGER.debug('patch_yaml_files end match %s: %s', ends, fname) _LOGGER.debug('patch_yaml_files end match %s: %s', ends, fname)
return StringIO(files_dict[ends]) res = StringIO(files_dict[ends])
setattr(res, 'name', fname)
return res
# Fallback for hass.components (i.e. services.yaml) # Fallback for hass.components (i.e. services.yaml)
if 'homeassistant/components' in fname: if 'homeassistant/components' in fname:
@ -268,6 +271,6 @@ def patch_yaml_files(files_dict, endswith=True):
return open(fname, encoding='utf-8') return open(fname, encoding='utf-8')
# Not found # Not found
raise IOError('File not found: {}'.format(fname)) raise FileNotFoundError('File not found: {}'.format(fname))
return patch.object(yaml, 'open', mock_open_f, create=True) return patch.object(yaml, 'open', mock_open_f, create=True)

View file

@ -137,7 +137,10 @@ class TestCheckConfig(unittest.TestCase):
self.maxDiff = None self.maxDiff = None
with patch_yaml_files(files): with patch_yaml_files(files):
res = check_config.check(get_test_config_dir('secret.yaml')) config_path = get_test_config_dir('secret.yaml')
secrets_path = get_test_config_dir('secrets.yaml')
res = check_config.check(config_path)
change_yaml_files(res) change_yaml_files(res)
# convert secrets OrderedDict to dict for assertequal # convert secrets OrderedDict to dict for assertequal
@ -148,7 +151,7 @@ class TestCheckConfig(unittest.TestCase):
'components': {'http': {'api_password': 'abc123', 'components': {'http': {'api_password': 'abc123',
'server_port': 8123}}, 'server_port': 8123}},
'except': {}, 'except': {},
'secret_cache': {'secrets.yaml': {'http_pw': 'abc123'}}, 'secret_cache': {secrets_path: {'http_pw': 'abc123'}},
'secrets': {'http_pw': 'abc123'}, 'secrets': {'http_pw': 'abc123'},
'yaml_files': ['.../secret.yaml', 'secrets.yaml'] 'yaml_files': ['.../secret.yaml', '.../secrets.yaml']
}, res) }, res)

View file

@ -3,59 +3,68 @@ import io
import unittest import unittest
import os import os
import tempfile import tempfile
from unittest.mock import patch
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml from homeassistant.util import yaml
import homeassistant.config as config_util from homeassistant.config import YAML_CONFIG_FILE, load_yaml_config_file
from tests.common import get_test_config_dir from tests.common import get_test_config_dir, patch_yaml_files
class TestYaml(unittest.TestCase): class TestYaml(unittest.TestCase):
"""Test util.yaml loader.""" """Test util.yaml loader."""
# pylint: disable=no-self-use,invalid-name
def test_simple_list(self): def test_simple_list(self):
"""Test simple list.""" """Test simple list."""
conf = "config:\n - simple\n - list" conf = "config:\n - simple\n - list"
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc['config'] == ["simple", "list"] assert doc['config'] == ["simple", "list"]
def test_simple_dict(self): def test_simple_dict(self):
"""Test simple dict.""" """Test simple dict."""
conf = "key: value" conf = "key: value"
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc['key'] == 'value' assert doc['key'] == 'value'
def test_duplicate_key(self): def test_duplicate_key(self):
"""Test simple dict.""" """Test duplicate dict keys."""
conf = "key: thing1\nkey: thing2" files = {YAML_CONFIG_FILE: 'key: thing1\nkey: thing2'}
try: with self.assertRaises(HomeAssistantError):
with io.StringIO(conf) as f: with patch_yaml_files(files):
yaml.yaml.safe_load(f) load_yaml_config_file(YAML_CONFIG_FILE)
except Exception:
pass def test_unhashable_key(self):
else: """Test an unhasable key."""
assert 0 files = {YAML_CONFIG_FILE: 'message:\n {{ states.state }}'}
with self.assertRaises(HomeAssistantError), \
patch_yaml_files(files):
load_yaml_config_file(YAML_CONFIG_FILE)
def test_no_key(self):
"""Test item without an key."""
files = {YAML_CONFIG_FILE: 'a: a\nnokeyhere'}
with self.assertRaises(HomeAssistantError), \
patch_yaml_files(files):
yaml.load_yaml(YAML_CONFIG_FILE)
def test_enviroment_variable(self): def test_enviroment_variable(self):
"""Test config file with enviroment variable.""" """Test config file with enviroment variable."""
os.environ["PASSWORD"] = "secret_password" os.environ["PASSWORD"] = "secret_password"
conf = "password: !env_var PASSWORD" conf = "password: !env_var PASSWORD"
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc['password'] == "secret_password" assert doc['password'] == "secret_password"
del os.environ["PASSWORD"] del os.environ["PASSWORD"]
def test_invalid_enviroment_variable(self): def test_invalid_enviroment_variable(self):
"""Test config file with no enviroment variable sat.""" """Test config file with no enviroment variable sat."""
conf = "password: !env_var PASSWORD" conf = "password: !env_var PASSWORD"
try: with self.assertRaises(HomeAssistantError):
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
yaml.yaml.safe_load(f) yaml.yaml.safe_load(file)
except Exception:
pass
else:
assert 0
def test_include_yaml(self): def test_include_yaml(self):
"""Test include yaml.""" """Test include yaml."""
@ -63,8 +72,8 @@ class TestYaml(unittest.TestCase):
include_file.write(b"value") include_file.write(b"value")
include_file.seek(0) include_file.seek(0)
conf = "key: !include {}".format(include_file.name) conf = "key: !include {}".format(include_file.name)
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc["key"] == "value" assert doc["key"] == "value"
def test_include_dir_list(self): def test_include_dir_list(self):
@ -79,8 +88,8 @@ class TestYaml(unittest.TestCase):
file_2.write(b"two") file_2.write(b"two")
file_2.close() file_2.close()
conf = "key: !include_dir_list {}".format(include_dir) conf = "key: !include_dir_list {}".format(include_dir)
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert sorted(doc["key"]) == sorted(["one", "two"]) assert sorted(doc["key"]) == sorted(["one", "two"])
def test_include_dir_named(self): def test_include_dir_named(self):
@ -98,8 +107,8 @@ class TestYaml(unittest.TestCase):
correct = {} correct = {}
correct[os.path.splitext(os.path.basename(file_1.name))[0]] = "one" correct[os.path.splitext(os.path.basename(file_1.name))[0]] = "one"
correct[os.path.splitext(os.path.basename(file_2.name))[0]] = "two" correct[os.path.splitext(os.path.basename(file_2.name))[0]] = "two"
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc["key"] == correct assert doc["key"] == correct
def test_include_dir_merge_list(self): def test_include_dir_merge_list(self):
@ -114,8 +123,8 @@ class TestYaml(unittest.TestCase):
file_2.write(b"- two\n- three") file_2.write(b"- two\n- three")
file_2.close() file_2.close()
conf = "key: !include_dir_merge_list {}".format(include_dir) conf = "key: !include_dir_merge_list {}".format(include_dir)
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert sorted(doc["key"]) == sorted(["one", "two", "three"]) assert sorted(doc["key"]) == sorted(["one", "two", "three"])
def test_include_dir_merge_named(self): def test_include_dir_merge_named(self):
@ -130,23 +139,25 @@ class TestYaml(unittest.TestCase):
file_2.write(b"key2: two\nkey3: three") file_2.write(b"key2: two\nkey3: three")
file_2.close() file_2.close()
conf = "key: !include_dir_merge_named {}".format(include_dir) conf = "key: !include_dir_merge_named {}".format(include_dir)
with io.StringIO(conf) as f: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(f) doc = yaml.yaml.safe_load(file)
assert doc["key"] == { assert doc["key"] == {
"key1": "one", "key1": "one",
"key2": "two", "key2": "two",
"key3": "three" "key3": "three"
} }
FILES = {}
def load_yaml(fname, string): def load_yaml(fname, string):
"""Write a string to file and return the parsed yaml.""" """Write a string to file and return the parsed yaml."""
with open(fname, 'w') as file: FILES[fname] = string
file.write(string) with patch_yaml_files(FILES):
return config_util.load_yaml_config_file(fname) return load_yaml_config_file(fname)
class FakeKeyring(): class FakeKeyring(): # pylint: disable=too-few-public-methods
"""Fake a keyring class.""" """Fake a keyring class."""
def __init__(self, secrets_dict): def __init__(self, secrets_dict):
@ -162,20 +173,16 @@ class FakeKeyring():
class TestSecrets(unittest.TestCase): class TestSecrets(unittest.TestCase):
"""Test the secrets parameter in the yaml utility.""" """Test the secrets parameter in the yaml utility."""
# pylint: disable=protected-access,invalid-name
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Create & load secrets file.""" """Create & load secrets file."""
config_dir = get_test_config_dir() config_dir = get_test_config_dir()
yaml.clear_secret_cache() yaml.clear_secret_cache()
self._yaml_path = os.path.join(config_dir, self._yaml_path = os.path.join(config_dir, YAML_CONFIG_FILE)
config_util.YAML_CONFIG_FILE)
self._secret_path = os.path.join(config_dir, yaml._SECRET_YAML) self._secret_path = os.path.join(config_dir, yaml._SECRET_YAML)
self._sub_folder_path = os.path.join(config_dir, 'subFolder') self._sub_folder_path = os.path.join(config_dir, 'subFolder')
if not os.path.exists(self._sub_folder_path):
os.makedirs(self._sub_folder_path)
self._unrelated_path = os.path.join(config_dir, 'unrelated') self._unrelated_path = os.path.join(config_dir, 'unrelated')
if not os.path.exists(self._unrelated_path):
os.makedirs(self._unrelated_path)
load_yaml(self._secret_path, load_yaml(self._secret_path,
'http_pw: pwhttp\n' 'http_pw: pwhttp\n'
@ -194,12 +201,7 @@ class TestSecrets(unittest.TestCase):
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Clean up secrets.""" """Clean up secrets."""
yaml.clear_secret_cache() yaml.clear_secret_cache()
for path in [self._yaml_path, self._secret_path, FILES.clear()
os.path.join(self._sub_folder_path, 'sub.yaml'),
os.path.join(self._sub_folder_path, yaml._SECRET_YAML),
os.path.join(self._unrelated_path, yaml._SECRET_YAML)]:
if os.path.isfile(path):
os.remove(path)
def test_secrets_from_yaml(self): def test_secrets_from_yaml(self):
"""Did secrets load ok.""" """Did secrets load ok."""
@ -263,3 +265,12 @@ class TestSecrets(unittest.TestCase):
"""Ensure logger: debug was removed.""" """Ensure logger: debug was removed."""
with self.assertRaises(yaml.HomeAssistantError): with self.assertRaises(yaml.HomeAssistantError):
load_yaml(self._yaml_path, 'api_password: !secret logger') load_yaml(self._yaml_path, 'api_password: !secret logger')
@patch('homeassistant.util.yaml._LOGGER.error')
def test_bad_logger_value(self, mock_error):
"""Ensure logger: debug was removed."""
yaml.clear_secret_cache()
load_yaml(self._secret_path, 'logger: info\npw: abc')
load_yaml(self._yaml_path, 'api_password: !secret pw')
assert mock_error.call_count == 1, \
"Expected an error about logger: value"