From 10feac11d960d1a6e8883ef8e6b01945adb6777c Mon Sep 17 00:00:00 2001 From: Lewis Juggins Date: Wed, 12 Oct 2016 11:05:41 +0100 Subject: [PATCH] Support recursive config inclusions (#3783) --- homeassistant/util/yaml.py | 33 ++++++++----- tests/util/test_yaml.py | 95 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 12 deletions(-) diff --git a/homeassistant/util/yaml.py b/homeassistant/util/yaml.py index 035a96b657e..cf773bb999f 100644 --- a/homeassistant/util/yaml.py +++ b/homeassistant/util/yaml.py @@ -1,8 +1,8 @@ """YAML utility functions.""" -import glob import logging import os import sys +import fnmatch from collections import OrderedDict from typing import Union, List, Dict @@ -61,23 +61,32 @@ def _include_yaml(loader: SafeLineLoader, return load_yaml(fname) +def _find_files(directory, pattern): + """Recursively load files in a directory.""" + for root, _dirs, files in os.walk(directory): + for basename in files: + if fnmatch.fnmatch(basename, pattern): + filename = os.path.join(root, basename) + yield filename + + def _include_dir_named_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> OrderedDict: """Load multiple files from directory as a dictionary.""" mapping = OrderedDict() # type: OrderedDict - files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') - for fname in glob.glob(files): + loc = os.path.join(os.path.dirname(loader.name), node.value) + for fname in _find_files(loc, '*.yaml'): filename = os.path.splitext(os.path.basename(fname))[0] mapping[filename] = load_yaml(fname) return mapping def _include_dir_merge_named_yaml(loader: SafeLineLoader, - node: yaml.nodes.Node): + node: yaml.nodes.Node) -> OrderedDict: """Load multiple files from directory as a merged dictionary.""" mapping = OrderedDict() # type: OrderedDict - files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') - for fname in glob.glob(files): + loc = os.path.join(os.path.dirname(loader.name), node.value) + for fname in _find_files(loc, '*.yaml'): if os.path.basename(fname) == _SECRET_YAML: continue loaded_yaml = load_yaml(fname) @@ -89,18 +98,18 @@ def _include_dir_merge_named_yaml(loader: SafeLineLoader, def _include_dir_list_yaml(loader: SafeLineLoader, node: yaml.nodes.Node): """Load multiple files from directory as a list.""" - files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') - return [load_yaml(f) for f in glob.glob(files) + loc = os.path.join(os.path.dirname(loader.name), node.value) + return [load_yaml(f) for f in _find_files(loc, '*.yaml') if os.path.basename(f) != _SECRET_YAML] def _include_dir_merge_list_yaml(loader: SafeLineLoader, node: yaml.nodes.Node): """Load multiple files from directory as a merged list.""" - files = os.path.join(os.path.dirname(loader.name), - node.value, '*.yaml') # type: str + loc = os.path.join(os.path.dirname(loader.name), + node.value) # type: str merged_list = [] # type: List - for fname in glob.glob(files): + for fname in _find_files(loc, '*.yaml'): if os.path.basename(fname) == _SECRET_YAML: continue loaded_yaml = load_yaml(fname) diff --git a/tests/util/test_yaml.py b/tests/util/test_yaml.py index 6b35e4f844c..b1214c2ff17 100644 --- a/tests/util/test_yaml.py +++ b/tests/util/test_yaml.py @@ -92,6 +92,27 @@ class TestYaml(unittest.TestCase): doc = yaml.yaml.safe_load(file) assert sorted(doc["key"]) == sorted(["one", "two"]) + def test_include_dir_list_recursive(self): + """Test include dir recursive list yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_0 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_0.write(b"zero") + file_0.close() + temp_dir = tempfile.TemporaryDirectory(dir=include_dir) + file_1 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_1.write(b"one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_2.write(b"two") + file_2.close() + conf = "key: !include_dir_list {}".format(include_dir) + with io.StringIO(conf) as file: + doc = yaml.yaml.safe_load(file) + assert sorted(doc["key"]) == sorted(["zero", "one", "two"]) + def test_include_dir_named(self): """Test include dir named yaml.""" with tempfile.TemporaryDirectory() as include_dir: @@ -111,6 +132,32 @@ class TestYaml(unittest.TestCase): doc = yaml.yaml.safe_load(file) assert doc["key"] == correct + def test_include_dir_named_recursive(self): + """Test include dir named yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_0 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_0.write(b"zero") + file_0.close() + temp_dir = tempfile.TemporaryDirectory(dir=include_dir) + file_1 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_1.write(b"one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_2.write(b"two") + file_2.close() + conf = "key: !include_dir_named {}".format(include_dir) + correct = {} + correct[os.path.splitext( + os.path.basename(file_0.name))[0]] = "zero" + correct[os.path.splitext(os.path.basename(file_1.name))[0]] = "one" + correct[os.path.splitext(os.path.basename(file_2.name))[0]] = "two" + with io.StringIO(conf) as file: + doc = yaml.yaml.safe_load(file) + assert doc["key"] == correct + def test_include_dir_merge_list(self): """Test include dir merge list yaml.""" with tempfile.TemporaryDirectory() as include_dir: @@ -127,6 +174,28 @@ class TestYaml(unittest.TestCase): doc = yaml.yaml.safe_load(file) assert sorted(doc["key"]) == sorted(["one", "two", "three"]) + def test_include_dir_merge_list_recursive(self): + """Test include dir merge list yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_0 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_0.write(b"- zero") + file_0.close() + temp_dir = tempfile.TemporaryDirectory(dir=include_dir) + file_1 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_1.write(b"- one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_2.write(b"- two\n- three") + file_2.close() + conf = "key: !include_dir_merge_list {}".format(include_dir) + with io.StringIO(conf) as file: + doc = yaml.yaml.safe_load(file) + assert sorted(doc["key"]) == sorted(["zero", "one", "two", + "three"]) + def test_include_dir_merge_named(self): """Test include dir merge named yaml.""" with tempfile.TemporaryDirectory() as include_dir: @@ -147,6 +216,32 @@ class TestYaml(unittest.TestCase): "key3": "three" } + def test_include_dir_merge_named_recursive(self): + """Test include dir merge named yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_0 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_0.write(b"key0: zero") + file_0.close() + temp_dir = tempfile.TemporaryDirectory(dir=include_dir) + file_1 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_1.write(b"key1: one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=temp_dir.name, + suffix=".yaml", delete=False) + file_2.write(b"key2: two\nkey3: three") + file_2.close() + conf = "key: !include_dir_merge_named {}".format(include_dir) + with io.StringIO(conf) as file: + doc = yaml.yaml.safe_load(file) + assert doc["key"] == { + "key0": "zero", + "key1": "one", + "key2": "two", + "key3": "three" + } + FILES = {}