Migrate tests to pytest (#23544)

* Migrate tests to pytest

* Fixup

* Use loop fixture in test_check_config

* Lint
This commit is contained in:
Erik Montnemery 2019-04-30 18:20:38 +02:00 committed by Paulus Schoutsen
parent d71424f285
commit 407e0c58f9
25 changed files with 4744 additions and 4910 deletions

View file

@ -1,25 +1,24 @@
"""The tests for the Entity component helper.""" """The tests for the Entity component helper."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio
from collections import OrderedDict from collections import OrderedDict
import logging import logging
import unittest
from unittest.mock import patch, Mock from unittest.mock import patch, Mock
from datetime import timedelta from datetime import timedelta
import asynctest
import pytest import pytest
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
from homeassistant.components import group from homeassistant.components import group
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, mock_coro, MockPlatform, MockModule, mock_coro,
async_fire_time_changed, MockEntity, MockConfigEntry, async_fire_time_changed, MockEntity, MockConfigEntry,
mock_entity_platform, mock_integration) mock_entity_platform, mock_integration)
@ -27,63 +26,51 @@ _LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain" DOMAIN = "test_domain"
class TestHelpersEntityComponent(unittest.TestCase): async def test_setting_up_group(hass):
"""Test homeassistant.helpers.entity_component module."""
def setUp(self): # pylint: disable=invalid-name
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name
"""Clean up the test Home Assistant instance."""
self.hass.stop()
def test_setting_up_group(self):
"""Set up the setting of a group.""" """Set up the setting of a group."""
setup_component(self.hass, 'group', {'group': {}}) assert await async_setup_component(hass, 'group', {'group': {}})
component = EntityComponent(_LOGGER, DOMAIN, self.hass, component = EntityComponent(_LOGGER, DOMAIN, hass,
group_name='everyone') group_name='everyone')
# No group after setup # No group after setup
assert len(self.hass.states.entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0
component.add_entities([MockEntity()]) await component.async_add_entities([MockEntity()])
self.hass.block_till_done() await hass.async_block_till_done()
# group exists # group exists
assert len(self.hass.states.entity_ids()) == 2 assert len(hass.states.async_entity_ids()) == 2
assert self.hass.states.entity_ids('group') == ['group.everyone'] assert hass.states.async_entity_ids('group') == ['group.everyone']
group = self.hass.states.get('group.everyone') grp = hass.states.get('group.everyone')
assert group.attributes.get('entity_id') == \ assert grp.attributes.get('entity_id') == \
('test_domain.unnamed_device',) ('test_domain.unnamed_device',)
# group extended # group extended
component.add_entities([MockEntity(name='goodbye')]) await component.async_add_entities([MockEntity(name='goodbye')])
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.hass.states.entity_ids()) == 3 assert len(hass.states.async_entity_ids()) == 3
group = self.hass.states.get('group.everyone') grp = hass.states.get('group.everyone')
# Ordered in order of added to the group # Ordered in order of added to the group
assert group.attributes.get('entity_id') == \ assert grp.attributes.get('entity_id') == \
('test_domain.goodbye', 'test_domain.unnamed_device') ('test_domain.goodbye', 'test_domain.unnamed_device')
def test_setup_loads_platforms(self):
async def test_setup_loads_platforms(hass):
"""Test the loading of the platforms.""" """Test the loading of the platforms."""
component_setup = Mock(return_value=True) component_setup = Mock(return_value=True)
platform_setup = Mock(return_value=None) platform_setup = Mock(return_value=None)
mock_integration(self.hass, mock_integration(hass, MockModule('test_component', setup=component_setup))
MockModule('test_component', setup=component_setup))
# mock the dependencies # mock the dependencies
mock_integration(self.hass, mock_integration(hass, MockModule('mod2', dependencies=['test_component']))
MockModule('mod2', dependencies=['test_component'])) mock_entity_platform(hass, 'test_domain.mod2',
mock_entity_platform(self.hass, 'test_domain.mod2',
MockPlatform(platform_setup)) MockPlatform(platform_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
assert not component_setup.called assert not component_setup.called
assert not platform_setup.called assert not platform_setup.called
@ -94,21 +81,22 @@ class TestHelpersEntityComponent(unittest.TestCase):
} }
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert component_setup.called assert component_setup.called
assert platform_setup.called assert platform_setup.called
def test_setup_recovers_when_setup_raises(self):
async def test_setup_recovers_when_setup_raises(hass):
"""Test the setup if exceptions are happening.""" """Test the setup if exceptions are happening."""
platform1_setup = Mock(side_effect=Exception('Broken')) platform1_setup = Mock(side_effect=Exception('Broken'))
platform2_setup = Mock(return_value=None) platform2_setup = Mock(return_value=None)
mock_entity_platform(self.hass, 'test_domain.mod1', mock_entity_platform(hass, 'test_domain.mod1',
MockPlatform(platform1_setup)) MockPlatform(platform1_setup))
mock_entity_platform(self.hass, 'test_domain.mod2', mock_entity_platform(hass, 'test_domain.mod2',
MockPlatform(platform2_setup)) MockPlatform(platform2_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
assert not platform1_setup.called assert not platform1_setup.called
assert not platform2_setup.called assert not platform2_setup.called
@ -119,41 +107,43 @@ class TestHelpersEntityComponent(unittest.TestCase):
("{} 3".format(DOMAIN), {'platform': 'mod2'}), ("{} 3".format(DOMAIN), {'platform': 'mod2'}),
])) ]))
self.hass.block_till_done() await hass.async_block_till_done()
assert platform1_setup.called assert platform1_setup.called
assert platform2_setup.called assert platform2_setup.called
@patch('homeassistant.helpers.entity_component.EntityComponent'
@asynctest.patch('homeassistant.helpers.entity_component.EntityComponent'
'._async_setup_platform', return_value=mock_coro()) '._async_setup_platform', return_value=mock_coro())
@patch('homeassistant.setup.async_setup_component', @asynctest.patch('homeassistant.setup.async_setup_component',
return_value=mock_coro(True)) return_value=mock_coro(True))
def test_setup_does_discovery(self, mock_setup_component, mock_setup): async def test_setup_does_discovery(mock_setup_component, mock_setup, hass):
"""Test setup for discovery.""" """Test setup for discovery."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
component.setup({}) component.setup({})
discovery.load_platform(self.hass, DOMAIN, 'platform_test', discovery.load_platform(hass, DOMAIN, 'platform_test',
{'msg': 'discovery_info'}, {DOMAIN: {}}) {'msg': 'discovery_info'}, {DOMAIN: {}})
self.hass.block_till_done() await hass.async_block_till_done()
assert mock_setup.called assert mock_setup.called
assert ('platform_test', {}, {'msg': 'discovery_info'}) == \ assert ('platform_test', {}, {'msg': 'discovery_info'}) == \
mock_setup.call_args[0] mock_setup.call_args[0]
@patch('homeassistant.helpers.entity_platform.'
@asynctest.patch('homeassistant.helpers.entity_platform.'
'async_track_time_interval') 'async_track_time_interval')
def test_set_scan_interval_via_config(self, mock_track): async def test_set_scan_interval_via_config(mock_track, hass):
"""Test the setting of the scan interval via configuration.""" """Test the setting of the scan interval via configuration."""
def platform_setup(hass, config, add_entities, discovery_info=None): def platform_setup(hass, config, add_entities, discovery_info=None):
"""Test the platform setup.""" """Test the platform setup."""
add_entities([MockEntity(should_poll=True)]) add_entities([MockEntity(should_poll=True)])
mock_entity_platform(self.hass, 'test_domain.platform', mock_entity_platform(hass, 'test_domain.platform',
MockPlatform(platform_setup)) MockPlatform(platform_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
component.setup({ component.setup({
DOMAIN: { DOMAIN: {
@ -162,11 +152,12 @@ class TestHelpersEntityComponent(unittest.TestCase):
} }
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert mock_track.called assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2] assert timedelta(seconds=30) == mock_track.call_args[0][2]
def test_set_entity_namespace_via_config(self):
async def test_set_entity_namespace_via_config(hass):
"""Test setting an entity namespace.""" """Test setting an entity namespace."""
def platform_setup(hass, config, add_entities, discovery_info=None): def platform_setup(hass, config, add_entities, discovery_info=None):
"""Test the platform setup.""" """Test the platform setup."""
@ -177,9 +168,9 @@ class TestHelpersEntityComponent(unittest.TestCase):
platform = MockPlatform(platform_setup) platform = MockPlatform(platform_setup)
mock_entity_platform(self.hass, 'test_domain.platform', platform) mock_entity_platform(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
component.setup({ component.setup({
DOMAIN: { DOMAIN: {
@ -188,17 +179,16 @@ class TestHelpersEntityComponent(unittest.TestCase):
} }
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert sorted(self.hass.states.entity_ids()) == \ assert sorted(hass.states.async_entity_ids()) == \
['test_domain.yummy_beer', 'test_domain.yummy_unnamed_device'] ['test_domain.yummy_beer', 'test_domain.yummy_unnamed_device']
@asyncio.coroutine async def test_extract_from_service_available_device(hass):
def test_extract_from_service_available_device(hass):
"""Test the extraction of entity from service and device is available.""" """Test the extraction of entity from service and device is available."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='test_1'), MockEntity(name='test_1'),
MockEntity(name='test_2', available=False), MockEntity(name='test_2', available=False),
MockEntity(name='test_3'), MockEntity(name='test_3'),
@ -209,7 +199,7 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_1', 'test_domain.test_3'] == \ assert ['test_domain.test_1', 'test_domain.test_3'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
(yield from component.async_extract_from_service(call_1))) (await component.async_extract_from_service(call_1)))
call_2 = ha.ServiceCall('test', 'service', data={ call_2 = ha.ServiceCall('test', 'service', data={
'entity_id': ['test_domain.test_3', 'test_domain.test_4'], 'entity_id': ['test_domain.test_3', 'test_domain.test_4'],
@ -217,11 +207,10 @@ def test_extract_from_service_available_device(hass):
assert ['test_domain.test_3'] == \ assert ['test_domain.test_3'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
(yield from component.async_extract_from_service(call_2))) (await component.async_extract_from_service(call_2)))
@asyncio.coroutine async def test_platform_not_ready(hass):
def test_platform_not_ready(hass):
"""Test that we retry when platform not ready.""" """Test that we retry when platform not ready."""
platform1_setup = Mock(side_effect=[PlatformNotReady, PlatformNotReady, platform1_setup = Mock(side_effect=[PlatformNotReady, PlatformNotReady,
None]) None])
@ -231,7 +220,7 @@ def test_platform_not_ready(hass):
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_setup({ await component.async_setup({
DOMAIN: { DOMAIN: {
'platform': 'mod1' 'platform': 'mod1'
} }
@ -245,32 +234,31 @@ def test_platform_not_ready(hass):
with patch('homeassistant.util.dt.utcnow', return_value=utcnow): with patch('homeassistant.util.dt.utcnow', return_value=utcnow):
# Should not trigger attempt 2 # Should not trigger attempt 2
async_fire_time_changed(hass, utcnow + timedelta(seconds=29)) async_fire_time_changed(hass, utcnow + timedelta(seconds=29))
yield from hass.async_block_till_done() await hass.async_block_till_done()
assert len(platform1_setup.mock_calls) == 1 assert len(platform1_setup.mock_calls) == 1
# Should trigger attempt 2 # Should trigger attempt 2
async_fire_time_changed(hass, utcnow + timedelta(seconds=30)) async_fire_time_changed(hass, utcnow + timedelta(seconds=30))
yield from hass.async_block_till_done() await hass.async_block_till_done()
assert len(platform1_setup.mock_calls) == 2 assert len(platform1_setup.mock_calls) == 2
assert 'test_domain.mod1' not in hass.config.components assert 'test_domain.mod1' not in hass.config.components
# This should not trigger attempt 3 # This should not trigger attempt 3
async_fire_time_changed(hass, utcnow + timedelta(seconds=59)) async_fire_time_changed(hass, utcnow + timedelta(seconds=59))
yield from hass.async_block_till_done() await hass.async_block_till_done()
assert len(platform1_setup.mock_calls) == 2 assert len(platform1_setup.mock_calls) == 2
# Trigger attempt 3, which succeeds # Trigger attempt 3, which succeeds
async_fire_time_changed(hass, utcnow + timedelta(seconds=60)) async_fire_time_changed(hass, utcnow + timedelta(seconds=60))
yield from hass.async_block_till_done() await hass.async_block_till_done()
assert len(platform1_setup.mock_calls) == 3 assert len(platform1_setup.mock_calls) == 3
assert 'test_domain.mod1' in hass.config.components assert 'test_domain.mod1' in hass.config.components
@asyncio.coroutine async def test_extract_from_service_returns_all_if_no_entity_id(hass):
def test_extract_from_service_returns_all_if_no_entity_id(hass):
"""Test the extraction of everything from service.""" """Test the extraction of everything from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='test_1'), MockEntity(name='test_1'),
MockEntity(name='test_2'), MockEntity(name='test_2'),
]) ])
@ -279,14 +267,13 @@ def test_extract_from_service_returns_all_if_no_entity_id(hass):
assert ['test_domain.test_1', 'test_domain.test_2'] == \ assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in sorted(ent.entity_id for ent in
(yield from component.async_extract_from_service(call))) (await component.async_extract_from_service(call)))
@asyncio.coroutine async def test_extract_from_service_filter_out_non_existing_entities(hass):
def test_extract_from_service_filter_out_non_existing_entities(hass):
"""Test the extraction of non existing entities from service.""" """Test the extraction of non existing entities from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='test_1'), MockEntity(name='test_1'),
MockEntity(name='test_2'), MockEntity(name='test_2'),
]) ])
@ -297,28 +284,26 @@ def test_extract_from_service_filter_out_non_existing_entities(hass):
assert ['test_domain.test_2'] == \ assert ['test_domain.test_2'] == \
[ent.entity_id for ent [ent.entity_id for ent
in (yield from component.async_extract_from_service(call))] in await component.async_extract_from_service(call)]
@asyncio.coroutine async def test_extract_from_service_no_group_expand(hass):
def test_extract_from_service_no_group_expand(hass):
"""Test not expanding a group.""" """Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
test_group = yield from group.Group.async_create_group( test_group = await group.Group.async_create_group(
hass, 'test_group', ['light.Ceiling', 'light.Kitchen']) hass, 'test_group', ['light.Ceiling', 'light.Kitchen'])
yield from component.async_add_entities([test_group]) await component.async_add_entities([test_group])
call = ha.ServiceCall('test', 'service', { call = ha.ServiceCall('test', 'service', {
'entity_id': ['group.test_group'] 'entity_id': ['group.test_group']
}) })
extracted = yield from component.async_extract_from_service( extracted = await component.async_extract_from_service(
call, expand_group=False) call, expand_group=False)
assert extracted == [test_group] assert extracted == [test_group]
@asyncio.coroutine async def test_setup_dependencies_platform(hass):
def test_setup_dependencies_platform(hass):
"""Test we setup the dependencies of a platform. """Test we setup the dependencies of a platform.
We're explictely testing that we process dependencies even if a component We're explictely testing that we process dependencies even if a component
@ -331,7 +316,7 @@ def test_setup_dependencies_platform(hass):
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_setup({ await component.async_setup({
DOMAIN: { DOMAIN: {
'platform': 'test_component', 'platform': 'test_component',
} }
@ -355,7 +340,7 @@ async def test_setup_entry(hass):
assert await component.async_setup_entry(entry) assert await component.async_setup_entry(entry)
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
p_hass, p_entry, p_add_entities = mock_setup_entry.mock_calls[0][1] p_hass, p_entry, _ = mock_setup_entry.mock_calls[0][1]
assert p_hass is hass assert p_hass is hass
assert p_entry is entry assert p_entry is entry
@ -448,7 +433,7 @@ async def test_set_service_race(hass):
await async_setup_component(hass, 'group', {}) await async_setup_component(hass, 'group', {})
component = EntityComponent(_LOGGER, DOMAIN, hass, group_name='yo') component = EntityComponent(_LOGGER, DOMAIN, hass, group_name='yo')
for i in range(2): for _ in range(2):
hass.async_create_task(component.async_add_entities([MockEntity()])) hass.async_create_task(component.async_add_entities([MockEntity()]))
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -1,14 +1,14 @@
"""Tests for the EntityPlatform helper.""" """Tests for the EntityPlatform helper."""
import asyncio import asyncio
import logging import logging
import unittest
from unittest.mock import patch, Mock, MagicMock from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta from datetime import timedelta
import asynctest
import pytest import pytest
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers.entity import generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_component import ( from homeassistant.helpers.entity_component import (
EntityComponent, DEFAULT_SCAN_INTERVAL) EntityComponent, DEFAULT_SCAN_INTERVAL)
from homeassistant.helpers import entity_platform, entity_registry from homeassistant.helpers import entity_platform, entity_registry
@ -16,7 +16,7 @@ from homeassistant.helpers import entity_platform, entity_registry
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry, MockPlatform, async_fire_time_changed, mock_registry,
MockEntity, MockEntityPlatform, MockConfigEntry, mock_entity_platform) MockEntity, MockEntityPlatform, MockConfigEntry, mock_entity_platform)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -24,42 +24,32 @@ DOMAIN = "test_domain"
PLATFORM = 'test_platform' PLATFORM = 'test_platform'
class TestHelpersEntityPlatform(unittest.TestCase): async def test_polling_only_updates_entities_it_should_poll(hass):
"""Test homeassistant.helpers.entity_component module."""
def setUp(self): # pylint: disable=invalid-name
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name
"""Clean up the test Home Assistant instance."""
self.hass.stop()
def test_polling_only_updates_entities_it_should_poll(self):
"""Test the polling of only updated entities.""" """Test the polling of only updated entities."""
component = EntityComponent( component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) _LOGGER, DOMAIN, hass, timedelta(seconds=20))
no_poll_ent = MockEntity(should_poll=False) no_poll_ent = MockEntity(should_poll=False)
no_poll_ent.async_update = Mock() no_poll_ent.async_update = Mock()
poll_ent = MockEntity(should_poll=True) poll_ent = MockEntity(should_poll=True)
poll_ent.async_update = Mock() poll_ent.async_update = Mock()
component.add_entities([no_poll_ent, poll_ent]) await component.async_add_entities([no_poll_ent, poll_ent])
no_poll_ent.async_update.reset_mock() no_poll_ent.async_update.reset_mock()
poll_ent.async_update.reset_mock() poll_ent.async_update.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done() await hass.async_block_till_done()
assert not no_poll_ent.async_update.called assert not no_poll_ent.async_update.called
assert poll_ent.async_update.called assert poll_ent.async_update.called
def test_polling_updates_entities_with_exception(self):
async def test_polling_updates_entities_with_exception(hass):
"""Test the updated entities that not break with an exception.""" """Test the updated entities that not break with an exception."""
component = EntityComponent( component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20)) _LOGGER, DOMAIN, hass, timedelta(seconds=20))
update_ok = [] update_ok = []
update_err = [] update_err = []
@ -82,64 +72,68 @@ class TestHelpersEntityPlatform(unittest.TestCase):
ent4 = MockEntity(should_poll=True) ent4 = MockEntity(should_poll=True)
ent4.update = update_mock ent4.update = update_mock
component.add_entities([ent1, ent2, ent3, ent4]) await component.async_add_entities([ent1, ent2, ent3, ent4])
update_ok.clear() update_ok.clear()
update_err.clear() update_err.clear()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20)) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(update_ok) == 3 assert len(update_ok) == 3
assert len(update_err) == 1 assert len(update_err) == 1
def test_update_state_adds_entities(self):
async def test_update_state_adds_entities(hass):
"""Test if updating poll entities cause an entity to be added works.""" """Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
ent1 = MockEntity() ent1 = MockEntity()
ent2 = MockEntity(should_poll=True) ent2 = MockEntity(should_poll=True)
component.add_entities([ent2]) await component.async_add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids()) assert len(hass.states.async_entity_ids()) == 1
ent2.update = lambda *_: component.add_entities([ent1]) ent2.update = lambda *_: component.add_entities([ent1])
fire_time_changed( async_fire_time_changed(
self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL
) )
self.hass.block_till_done() await hass.async_block_till_done()
assert 2 == len(self.hass.states.entity_ids()) assert len(hass.states.async_entity_ids()) == 2
def test_update_state_adds_entities_with_update_before_add_true(self):
async def test_update_state_adds_entities_with_update_before_add_true(hass):
"""Test if call update before add to state machine.""" """Test if call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
ent = MockEntity() ent = MockEntity()
ent.update = Mock(spec_set=True) ent.update = Mock(spec_set=True)
component.add_entities([ent], True) await component.async_add_entities([ent], True)
self.hass.block_till_done() await hass.async_block_till_done()
assert 1 == len(self.hass.states.entity_ids()) assert len(hass.states.async_entity_ids()) == 1
assert ent.update.called assert ent.update.called
def test_update_state_adds_entities_with_update_before_add_false(self):
async def test_update_state_adds_entities_with_update_before_add_false(hass):
"""Test if not call update before add to state machine.""" """Test if not call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
ent = MockEntity() ent = MockEntity()
ent.update = Mock(spec_set=True) ent.update = Mock(spec_set=True)
component.add_entities([ent], False) await component.async_add_entities([ent], False)
self.hass.block_till_done() await hass.async_block_till_done()
assert 1 == len(self.hass.states.entity_ids()) assert len(hass.states.async_entity_ids()) == 1
assert not ent.update.called assert not ent.update.called
@patch('homeassistant.helpers.entity_platform.'
@asynctest.patch('homeassistant.helpers.entity_platform.'
'async_track_time_interval') 'async_track_time_interval')
def test_set_scan_interval_via_platform(self, mock_track): async def test_set_scan_interval_via_platform(mock_track, hass):
"""Test the setting of the scan interval via platform.""" """Test the setting of the scan interval via platform."""
def platform_setup(hass, config, add_entities, discovery_info=None): def platform_setup(hass, config, add_entities, discovery_info=None):
"""Test the platform setup.""" """Test the platform setup."""
@ -148,9 +142,9 @@ class TestHelpersEntityPlatform(unittest.TestCase):
platform = MockPlatform(platform_setup) platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = timedelta(seconds=30) platform.SCAN_INTERVAL = timedelta(seconds=30)
mock_entity_platform(self.hass, 'test_domain.platform', platform) mock_entity_platform(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
component.setup({ component.setup({
DOMAIN: { DOMAIN: {
@ -158,30 +152,30 @@ class TestHelpersEntityPlatform(unittest.TestCase):
} }
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert mock_track.called assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2] assert timedelta(seconds=30) == mock_track.call_args[0][2]
def test_adding_entities_with_generator_and_thread_callback(self):
async def test_adding_entities_with_generator_and_thread_callback(hass):
"""Test generator in add_entities that calls thread method. """Test generator in add_entities that calls thread method.
We should make sure we resolve the generator to a list before passing We should make sure we resolve the generator to a list before passing
it into an async context. it into an async context.
""" """
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
def create_entity(number): def create_entity(number):
"""Create entity helper.""" """Create entity helper."""
entity = MockEntity() entity = MockEntity()
entity.entity_id = generate_entity_id(DOMAIN + '.{}', entity.entity_id = async_generate_entity_id(DOMAIN + '.{}',
'Number', hass=self.hass) 'Number', hass=hass)
return entity return entity
component.add_entities(create_entity(i) for i in range(2)) await component.async_add_entities(create_entity(i) for i in range(2))
@asyncio.coroutine async def test_platform_warn_slow_setup(hass):
def test_platform_warn_slow_setup(hass):
"""Warn we log when platform setup takes a long time.""" """Warn we log when platform setup takes a long time."""
platform = MockPlatform() platform = MockPlatform()
@ -191,7 +185,7 @@ def test_platform_warn_slow_setup(hass):
with patch.object(hass.loop, 'call_later', MagicMock()) \ with patch.object(hass.loop, 'call_later', MagicMock()) \
as mock_call: as mock_call:
yield from component.async_setup({ await component.async_setup({
DOMAIN: { DOMAIN: {
'platform': 'platform', 'platform': 'platform',
} }
@ -208,21 +202,19 @@ def test_platform_warn_slow_setup(hass):
assert mock_call().cancel.called assert mock_call().cancel.called
@asyncio.coroutine async def test_platform_error_slow_setup(hass, caplog):
def test_platform_error_slow_setup(hass, caplog):
"""Don't block startup more than SLOW_SETUP_MAX_WAIT.""" """Don't block startup more than SLOW_SETUP_MAX_WAIT."""
with patch.object(entity_platform, 'SLOW_SETUP_MAX_WAIT', 0): with patch.object(entity_platform, 'SLOW_SETUP_MAX_WAIT', 0):
called = [] called = []
@asyncio.coroutine async def setup_platform(*args):
def setup_platform(*args):
called.append(1) called.append(1)
yield from asyncio.sleep(1, loop=hass.loop) await asyncio.sleep(1, loop=hass.loop)
platform = MockPlatform(async_setup_platform=setup_platform) platform = MockPlatform(async_setup_platform=setup_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
mock_entity_platform(hass, 'test_domain.test_platform', platform) mock_entity_platform(hass, 'test_domain.test_platform', platform)
yield from component.async_setup({ await component.async_setup({
DOMAIN: { DOMAIN: {
'platform': 'test_platform', 'platform': 'test_platform',
} }
@ -232,23 +224,21 @@ def test_platform_error_slow_setup(hass, caplog):
assert 'test_platform is taking longer than 0 seconds' in caplog.text assert 'test_platform is taking longer than 0 seconds' in caplog.text
@asyncio.coroutine async def test_updated_state_used_for_entity_id(hass):
def test_updated_state_used_for_entity_id(hass):
"""Test that first update results used for entity ID generation.""" """Test that first update results used for entity ID generation."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
class MockEntityNameFetcher(MockEntity): class MockEntityNameFetcher(MockEntity):
"""Mock entity that fetches a friendly name.""" """Mock entity that fetches a friendly name."""
@asyncio.coroutine async def async_update(self):
def async_update(self):
"""Mock update that assigns a name.""" """Mock update that assigns a name."""
self._values['name'] = "Living Room" self._values['name'] = "Living Room"
yield from component.async_add_entities([MockEntityNameFetcher()], True) await component.async_add_entities([MockEntityNameFetcher()], True)
entity_ids = hass.states.async_entity_ids() entity_ids = hass.states.async_entity_ids()
assert 1 == len(entity_ids) assert len(entity_ids) == 1
assert entity_ids[0] == "test_domain.living_room" assert entity_ids[0] == "test_domain.living_room"
@ -374,8 +364,7 @@ async def test_parallel_updates_sync_platform_with_constant(hass):
assert entity.parallel_updates._value == 2 assert entity.parallel_updates._value == 2
@asyncio.coroutine async def test_raise_error_on_update(hass):
def test_raise_error_on_update(hass):
"""Test the add entity if they raise an error on update.""" """Test the add entity if they raise an error on update."""
updates = [] updates = []
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
@ -389,63 +378,58 @@ def test_raise_error_on_update(hass):
entity1.update = _raise entity1.update = _raise
entity2.update = lambda: updates.append(1) entity2.update = lambda: updates.append(1)
yield from component.async_add_entities([entity1, entity2], True) await component.async_add_entities([entity1, entity2], True)
assert len(updates) == 1 assert len(updates) == 1
assert 1 in updates assert 1 in updates
@asyncio.coroutine async def test_async_remove_with_platform(hass):
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform.""" """Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(name='test_1') entity1 = MockEntity(name='test_1')
yield from component.async_add_entities([entity1]) await component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1 assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove() await entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0
@asyncio.coroutine async def test_not_adding_duplicate_entities_with_unique_id(hass):
def test_not_adding_duplicate_entities_with_unique_id(hass):
"""Test for not adding duplicate entities.""" """Test for not adding duplicate entities."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='test1', unique_id='not_very_unique')]) MockEntity(name='test1', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1 assert len(hass.states.async_entity_ids()) == 1
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='test2', unique_id='not_very_unique')]) MockEntity(name='test2', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1 assert len(hass.states.async_entity_ids()) == 1
@asyncio.coroutine async def test_using_prescribed_entity_id(hass):
def test_using_prescribed_entity_id(hass):
"""Test for using predefined entity ID.""" """Test for using predefined entity ID."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='bla', entity_id='hello.world')]) MockEntity(name='bla', entity_id='hello.world')])
assert 'hello.world' in hass.states.async_entity_ids() assert 'hello.world' in hass.states.async_entity_ids()
@asyncio.coroutine async def test_using_prescribed_entity_id_with_unique_id(hass):
def test_using_prescribed_entity_id_with_unique_id(hass):
"""Test for ammending predefined entity ID because currently exists.""" """Test for ammending predefined entity ID because currently exists."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(entity_id='test_domain.world')]) MockEntity(entity_id='test_domain.world')])
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(entity_id='test_domain.world', unique_id='bla')]) MockEntity(entity_id='test_domain.world', unique_id='bla')])
assert 'test_domain.world_2' in hass.states.async_entity_ids() assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine async def test_using_prescribed_entity_id_which_is_registered(hass):
def test_using_prescribed_entity_id_which_is_registered(hass):
"""Test not allowing predefined entity ID that already registered.""" """Test not allowing predefined entity ID that already registered."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass) registry = mock_registry(hass)
@ -454,14 +438,13 @@ def test_using_prescribed_entity_id_which_is_registered(hass):
DOMAIN, 'test', '1234', suggested_object_id='world') DOMAIN, 'test', '1234', suggested_object_id='world')
# This entity_id will be rewritten # This entity_id will be rewritten
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(entity_id='test_domain.world')]) MockEntity(entity_id='test_domain.world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids() assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine async def test_name_which_conflict_with_registered(hass):
def test_name_which_conflict_with_registered(hass):
"""Test not generating conflicting entity ID based on name.""" """Test not generating conflicting entity ID based on name."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass) registry = mock_registry(hass)
@ -470,24 +453,22 @@ def test_name_which_conflict_with_registered(hass):
registry.async_get_or_create( registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world') DOMAIN, 'test', '1234', suggested_object_id='world')
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(name='world')]) MockEntity(name='world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids() assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine async def test_entity_with_name_and_entity_id_getting_registered(hass):
def test_entity_with_name_and_entity_id_getting_registered(hass):
"""Ensure that entity ID is used for registration.""" """Ensure that entity ID is used for registration."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(unique_id='1234', name='bla', MockEntity(unique_id='1234', name='bla',
entity_id='test_domain.world')]) entity_id='test_domain.world')])
assert 'test_domain.world' in hass.states.async_entity_ids() assert 'test_domain.world' in hass.states.async_entity_ids()
@asyncio.coroutine async def test_overriding_name_from_registry(hass):
def test_overriding_name_from_registry(hass):
"""Test that we can override a name via the Entity Registry.""" """Test that we can override a name via the Entity Registry."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
mock_registry(hass, { mock_registry(hass, {
@ -499,7 +480,7 @@ def test_overriding_name_from_registry(hass):
name='Overridden' name='Overridden'
) )
}) })
yield from component.async_add_entities([ await component.async_add_entities([
MockEntity(unique_id='1234', name='Device Name')]) MockEntity(unique_id='1234', name='Device Name')])
state = hass.states.get('test_domain.world') state = hass.states.get('test_domain.world')
@ -507,18 +488,16 @@ def test_overriding_name_from_registry(hass):
assert state.name == 'Overridden' assert state.name == 'Overridden'
@asyncio.coroutine async def test_registry_respect_entity_namespace(hass):
def test_registry_respect_entity_namespace(hass):
"""Test that the registry respects entity namespace.""" """Test that the registry respects entity namespace."""
mock_registry(hass) mock_registry(hass)
platform = MockEntityPlatform(hass, entity_namespace='ns') platform = MockEntityPlatform(hass, entity_namespace='ns')
entity = MockEntity(unique_id='1234', name='Device Name') entity = MockEntity(unique_id='1234', name='Device Name')
yield from platform.async_add_entities([entity]) await platform.async_add_entities([entity])
assert entity.entity_id == 'test_domain.ns_device_name' assert entity.entity_id == 'test_domain.ns_device_name'
@asyncio.coroutine async def test_registry_respect_entity_disabled(hass):
def test_registry_respect_entity_disabled(hass):
"""Test that the registry respects entity disabled.""" """Test that the registry respects entity disabled."""
mock_registry(hass, { mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry( 'test_domain.world': entity_registry.RegistryEntry(
@ -531,7 +510,7 @@ def test_registry_respect_entity_disabled(hass):
}) })
platform = MockEntityPlatform(hass) platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234') entity = MockEntity(unique_id='1234')
yield from platform.async_add_entities([entity]) await platform.async_add_entities([entity])
assert entity.entity_id is None assert entity.entity_id is None
assert hass.states.async_entity_ids() == [] assert hass.states.async_entity_ids() == []
@ -643,12 +622,11 @@ async def test_reset_cancels_retry_setup(hass):
assert ent_platform._async_cancel_retry_setup is None assert ent_platform._async_cancel_retry_setup is None
@asyncio.coroutine async def test_not_fails_with_adding_empty_entities_(hass):
def test_not_fails_with_adding_empty_entities_(hass):
"""Test for not fails on empty entities list.""" """Test for not fails on empty entities list."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([]) await component.async_add_entities([])
assert len(hass.states.async_entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0

File diff suppressed because it is too large Load diff

View file

@ -1,28 +1,18 @@
"""Test Home Assistant icon util methods.""" """Test Home Assistant icon util methods."""
import unittest
class TestIconUtil(unittest.TestCase): def test_battery_icon():
"""Test icon util methods."""
def test_battery_icon(self):
"""Test icon generator for battery sensor.""" """Test icon generator for battery sensor."""
from homeassistant.helpers.icon import icon_for_battery_level from homeassistant.helpers.icon import icon_for_battery_level
assert 'mdi:battery-unknown' == \ assert icon_for_battery_level(None, True) == 'mdi:battery-unknown'
icon_for_battery_level(None, True) assert icon_for_battery_level(None, False) == 'mdi:battery-unknown'
assert 'mdi:battery-unknown' == \
icon_for_battery_level(None, False)
assert 'mdi:battery-outline' == \ assert icon_for_battery_level(5, True) == 'mdi:battery-outline'
icon_for_battery_level(5, True) assert icon_for_battery_level(5, False) == 'mdi:battery-alert'
assert 'mdi:battery-alert' == \
icon_for_battery_level(5, False)
assert 'mdi:battery-charging-100' == \ assert icon_for_battery_level(100, True) == 'mdi:battery-charging-100'
icon_for_battery_level(100, True) assert icon_for_battery_level(100, False) == 'mdi:battery'
assert 'mdi:battery' == \
icon_for_battery_level(100, False)
iconbase = 'mdi:battery' iconbase = 'mdi:battery'
for level in range(0, 100, 5): for level in range(0, 100, 5):

View file

@ -1,27 +1,11 @@
"""Test component helpers.""" """Test component helpers."""
# pylint: disable=protected-access # pylint: disable=protected-access
from collections import OrderedDict from collections import OrderedDict
import unittest
from homeassistant import helpers from homeassistant import helpers
from tests.common import get_test_home_assistant
def test_extract_domain_configs():
class TestHelpers(unittest.TestCase):
"""Tests homeassistant.helpers module."""
# pylint: disable=invalid-name
def setUp(self):
"""Init needed objects."""
self.hass = get_test_home_assistant()
# pylint: disable=invalid-name
def tearDown(self):
"""Stop everything that was started."""
self.hass.stop()
def test_extract_domain_configs(self):
"""Test the extraction of domain configuration.""" """Test the extraction of domain configuration."""
config = { config = {
'zone': None, 'zone': None,
@ -34,7 +18,8 @@ class TestHelpers(unittest.TestCase):
assert set(['zone', 'zone Hallo', 'zone 100']) == \ assert set(['zone', 'zone Hallo', 'zone 100']) == \
set(helpers.extract_domain_configs(config, 'zone')) set(helpers.extract_domain_configs(config, 'zone'))
def test_config_per_platform(self):
def test_config_per_platform():
"""Test config per platform method.""" """Test config per platform method."""
config = OrderedDict([ config = OrderedDict([
('zone', {'platform': 'hello'}), ('zone', {'platform': 'hello'}),

View file

@ -1,11 +1,11 @@
"""Tests for the intent helpers.""" """Tests for the intent helpers."""
import unittest
import voluptuous as vol import voluptuous as vol
import pytest
from homeassistant.core import State from homeassistant.core import State
from homeassistant.helpers import (intent, config_validation as cv) from homeassistant.helpers import (intent, config_validation as cv)
import pytest
class MockIntentHandler(intent.IntentHandler): class MockIntentHandler(intent.IntentHandler):
@ -25,10 +25,7 @@ def test_async_match_state():
assert state is state1 assert state is state1
class TestIntentHandler(unittest.TestCase): def test_async_validate_slots():
"""Test the Home Assistant event helpers."""
def test_async_validate_slots(self):
"""Test async_validate_slots of IntentHandler.""" """Test async_validate_slots of IntentHandler."""
handler1 = MockIntentHandler({ handler1 = MockIntentHandler({
vol.Required('name'): cv.string, vol.Required('name'): cv.string,

View file

@ -1,20 +1,16 @@
"""Tests Home Assistant location helpers.""" """Tests Home Assistant location helpers."""
import unittest
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State from homeassistant.core import State
from homeassistant.helpers import location from homeassistant.helpers import location
class TestHelpersLocation(unittest.TestCase): def test_has_location_with_invalid_states():
"""Set up the tests."""
def test_has_location_with_invalid_states(self):
"""Set up the tests.""" """Set up the tests."""
for state in (None, 1, "hello", object): for state in (None, 1, "hello", object):
assert not location.has_location(state) assert not location.has_location(state)
def test_has_location_with_states_with_invalid_locations(self):
def test_has_location_with_states_with_invalid_locations():
"""Set up the tests.""" """Set up the tests."""
state = State('hello.world', 'invalid', { state = State('hello.world', 'invalid', {
ATTR_LATITUDE: 'no number', ATTR_LATITUDE: 'no number',
@ -22,7 +18,8 @@ class TestHelpersLocation(unittest.TestCase):
}) })
assert not location.has_location(state) assert not location.has_location(state)
def test_has_location_with_states_with_valid_location(self):
def test_has_location_with_states_with_valid_location():
"""Set up the tests.""" """Set up the tests."""
state = State('hello.world', 'invalid', { state = State('hello.world', 'invalid', {
ATTR_LATITUDE: 123.12, ATTR_LATITUDE: 123.12,
@ -30,7 +27,8 @@ class TestHelpersLocation(unittest.TestCase):
}) })
assert location.has_location(state) assert location.has_location(state)
def test_closest_with_no_states_with_location(self):
def test_closest_with_no_states_with_location():
"""Set up the tests.""" """Set up the tests."""
state = State('light.test', 'on') state = State('light.test', 'on')
state2 = State('light.test', 'on', { state2 = State('light.test', 'on', {
@ -44,7 +42,8 @@ class TestHelpersLocation(unittest.TestCase):
assert \ assert \
location.closest(123.45, 123.45, [state, state2, state3]) is None location.closest(123.45, 123.45, [state, state2, state3]) is None
def test_closest_returns_closest(self):
def test_closest_returns_closest():
"""Test .""" """Test ."""
state = State('light.test', 'on', { state = State('light.test', 'on', {
ATTR_LATITUDE: 124.45, ATTR_LATITUDE: 124.45,

View file

@ -1,9 +1,10 @@
"""The tests for the Script component.""" """The tests for the Script component."""
# pylint: disable=protected-access # pylint: disable=protected-access
from datetime import timedelta from datetime import timedelta
import functools as ft
from unittest import mock from unittest import mock
import unittest
import asynctest
import jinja2 import jinja2
import voluptuous as vol import voluptuous as vol
import pytest import pytest
@ -11,30 +12,16 @@ import pytest
from homeassistant import exceptions from homeassistant import exceptions
from homeassistant.core import Context, callback from homeassistant.core import Context, callback
# Otherwise can't test just this file (import order issue) # Otherwise can't test just this file (import order issue)
import homeassistant.components # noqa
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.helpers import script, config_validation as cv from homeassistant.helpers import script, config_validation as cv
from tests.common import fire_time_changed, get_test_home_assistant from tests.common import async_fire_time_changed
ENTITY_ID = 'script.test' ENTITY_ID = 'script.test'
class TestScriptHelper(unittest.TestCase): async def test_firing_event(hass):
"""Test the Script component."""
# pylint: disable=invalid-name
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
# pylint: disable=invalid-name
def tearDown(self):
"""Stop down everything that was started."""
self.hass.stop()
def test_firing_event(self):
"""Test the firing of events.""" """Test the firing of events."""
event = 'test_event' event = 'test_event'
context = Context() context = Context()
@ -45,25 +32,26 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(event) calls.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA({ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA({
'event': event, 'event': event,
'event_data': { 'event_data': {
'hello': 'world' 'hello': 'world'
} }
})) }))
script_obj.run(context=context) await script_obj.async_run(context=context)
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
assert not script_obj.can_cancel assert not script_obj.can_cancel
def test_firing_event_template(self):
async def test_firing_event_template(hass):
"""Test the firing of events.""" """Test the firing of events."""
event = 'test_event' event = 'test_event'
context = Context() context = Context()
@ -74,9 +62,9 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(event) calls.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA({ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA({
'event': event, 'event': event,
'event_data_template': { 'event_data_template': {
'dict': { 'dict': {
@ -90,9 +78,9 @@ class TestScriptHelper(unittest.TestCase):
} }
})) }))
script_obj.run({'is_world': 'yes'}, context=context) await script_obj.async_run({'is_world': 'yes'}, context=context)
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context assert calls[0].context is context
@ -106,7 +94,8 @@ class TestScriptHelper(unittest.TestCase):
} }
assert not script_obj.can_cancel assert not script_obj.can_cancel
def test_calling_service(self):
async def test_calling_service(hass):
"""Test the calling of a service.""" """Test the calling of a service."""
calls = [] calls = []
context = Context() context = Context()
@ -116,22 +105,24 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(service) calls.append(service)
self.hass.services.register('test', 'script', record_call) hass.services.async_register('test', 'script', record_call)
script.call_from_config(self.hass, { hass.async_add_job(
ft.partial(script.call_from_config, hass, {
'service': 'test.script', 'service': 'test.script',
'data': { 'data': {
'hello': 'world' 'hello': 'world'
} }
}, context=context) }, context=context))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
def test_calling_service_template(self):
async def test_calling_service_template(hass):
"""Test the calling of a service.""" """Test the calling of a service."""
calls = [] calls = []
context = Context() context = Context()
@ -141,9 +132,10 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(service) calls.append(service)
self.hass.services.register('test', 'script', record_call) hass.services.async_register('test', 'script', record_call)
script.call_from_config(self.hass, { hass.async_add_job(
ft.partial(script.call_from_config, hass, {
'service_template': """ 'service_template': """
{% if True %} {% if True %}
test.script test.script
@ -159,15 +151,16 @@ class TestScriptHelper(unittest.TestCase):
{% endif %} {% endif %}
""" """
} }
}, {'is_world': 'yes'}, context=context) }, {'is_world': 'yes'}, context=context))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
def test_delay(self):
async def test_delay(hass):
"""Test the delay.""" """Test the delay."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -179,15 +172,15 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': {'seconds': 5}, 'alias': delay_alias}, {'delay': {'seconds': 5}, 'alias': delay_alias},
{'event': event}])) {'event': event}]))
script_obj.run(context=context) await script_obj.async_run(context=context)
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -195,15 +188,16 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
assert events[0].context is context assert events[0].context is context
assert events[1].context is context assert events[1].context is context
def test_delay_template(self):
async def test_delay_template(hass):
"""Test the delay as a template.""" """Test the delay as a template."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -214,15 +208,15 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': '00:00:{{ 5 }}', 'alias': delay_alias}, {'delay': '00:00:{{ 5 }}', 'alias': delay_alias},
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -230,13 +224,14 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
def test_delay_invalid_template(self):
async def test_delay_invalid_template(hass):
"""Test the delay as a template that fails.""" """Test the delay as a template that fails."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -246,23 +241,24 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': '{{ invalid_delay }}'}, {'delay': '{{ invalid_delay }}'},
{'delay': {'seconds': 5}}, {'delay': {'seconds': 5}},
{'event': event}])) {'event': event}]))
with mock.patch.object(script, '_LOGGER') as mock_logger: with mock.patch.object(script, '_LOGGER') as mock_logger:
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert mock_logger.error.called assert mock_logger.error.called
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 1 assert len(events) == 1
def test_delay_complex_template(self):
async def test_delay_complex_template(hass):
"""Test the delay with a working complex template.""" """Test the delay with a working complex template."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -273,17 +269,17 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': { {'delay': {
'seconds': '{{ 5 }}'}, 'seconds': '{{ 5 }}'},
'alias': delay_alias}, 'alias': delay_alias},
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -291,13 +287,14 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
def test_delay_complex_invalid_template(self):
async def test_delay_complex_invalid_template(hass):
"""Test the delay with a complex template that fails.""" """Test the delay with a complex template that fails."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -307,9 +304,9 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': { {'delay': {
'seconds': '{{ invalid_delay }}' 'seconds': '{{ invalid_delay }}'
@ -320,14 +317,15 @@ class TestScriptHelper(unittest.TestCase):
{'event': event}])) {'event': event}]))
with mock.patch.object(script, '_LOGGER') as mock_logger: with mock.patch.object(script, '_LOGGER') as mock_logger:
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert mock_logger.error.called assert mock_logger.error.called
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 1 assert len(events) == 1
def test_cancel_while_delay(self):
async def test_cancel_while_delay(hass):
"""Test the cancelling while the delay is present.""" """Test the cancelling while the delay is present."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -337,31 +335,32 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'delay': {'seconds': 5}}, {'delay': {'seconds': 5}},
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert len(events) == 0 assert len(events) == 0
script_obj.stop() script_obj.async_stop()
assert not script_obj.is_running assert not script_obj.is_running
# Make sure the script is really stopped. # Make sure the script is really stopped.
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 0 assert len(events) == 0
def test_wait_template(self):
async def test_wait_template(hass):
"""Test the wait template.""" """Test the wait template."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -373,33 +372,34 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'wait_template': "{{states.switch.test.state == 'off'}}", {'wait_template': "{{states.switch.test.state == 'off'}}",
'alias': wait_alias}, 'alias': wait_alias},
{'event': event}])) {'event': event}]))
script_obj.run(context=context) await script_obj.async_run(context=context)
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
assert script_obj.last_action == wait_alias assert script_obj.last_action == wait_alias
assert len(events) == 1 assert len(events) == 1
self.hass.states.set('switch.test', 'off') hass.states.async_set('switch.test', 'off')
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
assert events[0].context is context assert events[0].context is context
assert events[1].context is context assert events[1].context is context
def test_wait_template_cancel(self):
async def test_wait_template_cancel(hass):
"""Test the wait template cancel action.""" """Test the wait template cancel action."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -410,36 +410,37 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'wait_template': "{{states.switch.test.state == 'off'}}", {'wait_template': "{{states.switch.test.state == 'off'}}",
'alias': wait_alias}, 'alias': wait_alias},
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
assert script_obj.last_action == wait_alias assert script_obj.last_action == wait_alias
assert len(events) == 1 assert len(events) == 1
script_obj.stop() script_obj.async_stop()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 1 assert len(events) == 1
self.hass.states.set('switch.test', 'off') hass.states.async_set('switch.test', 'off')
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 1 assert len(events) == 1
def test_wait_template_not_schedule(self):
async def test_wait_template_not_schedule(hass):
"""Test the wait template with correct condition.""" """Test the wait template with correct condition."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -449,23 +450,24 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'wait_template': "{{states.switch.test.state == 'on'}}"}, {'wait_template': "{{states.switch.test.state == 'on'}}"},
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
assert len(events) == 2 assert len(events) == 2
def test_wait_template_timeout_halt(self):
async def test_wait_template_timeout_halt(hass):
"""Test the wait template, halt on timeout.""" """Test the wait template, halt on timeout."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -476,11 +478,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'wait_template': "{{states.switch.test.state == 'off'}}", 'wait_template': "{{states.switch.test.state == 'off'}}",
@ -490,8 +492,8 @@ class TestScriptHelper(unittest.TestCase):
}, },
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -499,13 +501,14 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 1 assert len(events) == 1
def test_wait_template_timeout_continue(self):
async def test_wait_template_timeout_continue(hass):
"""Test the wait template with continuing the script.""" """Test the wait template with continuing the script."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -516,11 +519,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'wait_template': "{{states.switch.test.state == 'off'}}", 'wait_template': "{{states.switch.test.state == 'off'}}",
@ -530,8 +533,8 @@ class TestScriptHelper(unittest.TestCase):
}, },
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -539,13 +542,14 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
def test_wait_template_timeout_default(self):
async def test_wait_template_timeout_default(hass):
"""Test the wait template with default contiune.""" """Test the wait template with default contiune."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -556,11 +560,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'wait_template': "{{states.switch.test.state == 'off'}}", 'wait_template': "{{states.switch.test.state == 'off'}}",
@ -569,8 +573,8 @@ class TestScriptHelper(unittest.TestCase):
}, },
{'event': event}])) {'event': event}]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
@ -578,13 +582,14 @@ class TestScriptHelper(unittest.TestCase):
assert len(events) == 1 assert len(events) == 1
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
def test_wait_template_variables(self):
async def test_wait_template_variables(hass):
"""Test the wait template with variables.""" """Test the wait template with variables."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -595,33 +600,34 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('switch.test', 'on') hass.states.async_set('switch.test', 'on')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'wait_template': "{{is_state(data, 'off')}}", {'wait_template': "{{is_state(data, 'off')}}",
'alias': wait_alias}, 'alias': wait_alias},
{'event': event}])) {'event': event}]))
script_obj.run({ await script_obj.async_run({
'data': 'switch.test' 'data': 'switch.test'
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert script_obj.can_cancel assert script_obj.can_cancel
assert script_obj.last_action == wait_alias assert script_obj.last_action == wait_alias
assert len(events) == 1 assert len(events) == 1
self.hass.states.set('switch.test', 'off') hass.states.async_set('switch.test', 'off')
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
def test_passing_variables_to_script(self):
async def test_passing_variables_to_script(hass):
"""Test if we can pass variables to script.""" """Test if we can pass variables to script."""
calls = [] calls = []
@ -630,9 +636,9 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
calls.append(service) calls.append(service)
self.hass.services.register('test', 'script', record_call) hass.services.async_register('test', 'script', record_call)
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{ {
'service': 'test.script', 'service': 'test.script',
'data_template': { 'data_template': {
@ -647,27 +653,28 @@ class TestScriptHelper(unittest.TestCase):
}, },
}])) }]))
script_obj.run({ await script_obj.async_run({
'greeting': 'world', 'greeting': 'world',
'greeting2': 'universe', 'greeting2': 'universe',
'delay_period': '00:00:05' 'delay_period': '00:00:05'
}) })
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.is_running assert script_obj.is_running
assert len(calls) == 1 assert len(calls) == 1
assert calls[-1].data['hello'] == 'world' assert calls[-1].data['hello'] == 'world'
future = dt_util.utcnow() + timedelta(seconds=5) future = dt_util.utcnow() + timedelta(seconds=5)
fire_time_changed(self.hass, future) async_fire_time_changed(hass, future)
self.hass.block_till_done() await hass.async_block_till_done()
assert not script_obj.is_running assert not script_obj.is_running
assert len(calls) == 2 assert len(calls) == 2
assert calls[-1].data['hello'] == 'universe' assert calls[-1].data['hello'] == 'universe'
def test_condition(self):
async def test_condition(hass):
"""Test if we can use conditions in a script.""" """Test if we can use conditions in a script."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -677,11 +684,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('test.entity', 'hello') hass.states.async_set('test.entity', 'hello')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'condition': 'template', 'condition': 'template',
@ -690,18 +697,19 @@ class TestScriptHelper(unittest.TestCase):
{'event': event}, {'event': event},
])) ]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert len(events) == 2 assert len(events) == 2
self.hass.states.set('test.entity', 'goodbye') hass.states.async_set('test.entity', 'goodbye')
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert len(events) == 3 assert len(events) == 3
@mock.patch('homeassistant.helpers.script.condition.async_from_config')
def test_condition_created_once(self, async_from_config): @asynctest.patch('homeassistant.helpers.script.condition.async_from_config')
async def test_condition_created_once(async_from_config, hass):
"""Test that the conditions do not get created multiple times.""" """Test that the conditions do not get created multiple times."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -711,11 +719,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('test.entity', 'hello') hass.states.async_set('test.entity', 'hello')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'condition': 'template', 'condition': 'template',
@ -724,13 +732,14 @@ class TestScriptHelper(unittest.TestCase):
{'event': event}, {'event': event},
])) ]))
script_obj.run() await script_obj.async_run()
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert async_from_config.call_count == 1 assert async_from_config.call_count == 1
assert len(script_obj._config_cache) == 1 assert len(script_obj._config_cache) == 1
def test_all_conditions_cached(self):
async def test_all_conditions_cached(hass):
"""Test that multiple conditions get cached.""" """Test that multiple conditions get cached."""
event = 'test_event' event = 'test_event'
events = [] events = []
@ -740,11 +749,11 @@ class TestScriptHelper(unittest.TestCase):
"""Add recorded event to set.""" """Add recorded event to set."""
events.append(event) events.append(event)
self.hass.bus.listen(event, record_event) hass.bus.async_listen(event, record_event)
self.hass.states.set('test.entity', 'hello') hass.states.async_set('test.entity', 'hello')
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{ {
'condition': 'template', 'condition': 'template',
@ -757,15 +766,16 @@ class TestScriptHelper(unittest.TestCase):
{'event': event}, {'event': event},
])) ]))
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert len(script_obj._config_cache) == 2 assert len(script_obj._config_cache) == 2
def test_last_triggered(self):
async def test_last_triggered(hass):
"""Test the last_triggered.""" """Test the last_triggered."""
event = 'test_event' event = 'test_event'
script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'event': event}, {'event': event},
{'delay': {'seconds': 5}}, {'delay': {'seconds': 5}},
{'event': event}])) {'event': event}]))
@ -775,8 +785,8 @@ class TestScriptHelper(unittest.TestCase):
time = dt_util.utcnow() time = dt_util.utcnow()
with mock.patch('homeassistant.helpers.script.date_util.utcnow', with mock.patch('homeassistant.helpers.script.date_util.utcnow',
return_value=time): return_value=time):
script_obj.run() await script_obj.async_run()
self.hass.block_till_done() await hass.async_block_till_done()
assert script_obj.last_triggered == time assert script_obj.last_triggered == time
@ -874,17 +884,19 @@ def test_log_exception():
for exc, msg in ( for exc, msg in (
(vol.Invalid("Invalid number"), 'Invalid data'), (vol.Invalid("Invalid number"), 'Invalid data'),
(exceptions.TemplateError(jinja2.TemplateError('Unclosed bracket')), (exceptions.TemplateError(
jinja2.TemplateError('Unclosed bracket')),
'Error rendering template'), 'Error rendering template'),
(exceptions.Unauthorized(), 'Unauthorized'), (exceptions.Unauthorized(), 'Unauthorized'),
(exceptions.ServiceNotFound('light', 'turn_on'), 'Service not found'), (exceptions.ServiceNotFound('light', 'turn_on'),
'Service not found'),
(ValueError("Cannot parse JSON"), 'Unknown error'), (ValueError("Cannot parse JSON"), 'Unknown error'),
): ):
logger = mock.Mock() logger = mock.Mock()
script_obj.async_log_exception(logger, 'Test error', exc) script_obj.async_log_exception(logger, 'Test error', exc)
assert len(logger.mock_calls) == 1 assert len(logger.mock_calls) == 1
p_format, p_msg_base, p_error_desc, p_action_type, p_step, p_error = \ _, _, p_error_desc, p_action_type, p_step, p_error = \
logger.mock_calls[0][1] logger.mock_calls[0][1]
assert p_error_desc == msg assert p_error_desc == msg

View file

@ -1,13 +1,12 @@
"""Test state helpers.""" """Test state helpers."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
import unittest
from unittest.mock import patch from unittest.mock import patch
import pytest
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.setup import async_setup_component
from homeassistant.const import (SERVICE_TURN_ON, SERVICE_TURN_OFF) from homeassistant.const import (SERVICE_TURN_ON, SERVICE_TURN_OFF)
from homeassistant.util.async_ import run_coroutine_threadsafe
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.helpers import state from homeassistant.helpers import state
from homeassistant.const import ( from homeassistant.const import (
@ -18,8 +17,7 @@ from homeassistant.const import (
from homeassistant.components.sun import (STATE_ABOVE_HORIZON, from homeassistant.components.sun import (STATE_ABOVE_HORIZON,
STATE_BELOW_HORIZON) STATE_BELOW_HORIZON)
from tests.common import get_test_home_assistant, mock_service from tests.common import async_mock_service
import pytest
@asyncio.coroutine @asyncio.coroutine
@ -82,139 +80,132 @@ def test_call_to_component(hass):
context=context) context=context)
class TestStateHelpers(unittest.TestCase): async def test_get_changed_since(hass):
"""Test the Home Assistant event helpers."""
def setUp(self): # pylint: disable=invalid-name
"""Run when tests are started."""
self.hass = get_test_home_assistant()
run_coroutine_threadsafe(async_setup_component(
self.hass, 'homeassistant', {}), self.hass.loop).result()
def tearDown(self): # pylint: disable=invalid-name
"""Stop when tests are finished."""
self.hass.stop()
def test_get_changed_since(self):
"""Test get_changed_since.""" """Test get_changed_since."""
point1 = dt_util.utcnow() point1 = dt_util.utcnow()
point2 = point1 + timedelta(seconds=5) point2 = point1 + timedelta(seconds=5)
point3 = point2 + timedelta(seconds=5) point3 = point2 + timedelta(seconds=5)
with patch('homeassistant.core.dt_util.utcnow', return_value=point1): with patch('homeassistant.core.dt_util.utcnow', return_value=point1):
self.hass.states.set('light.test', 'on') hass.states.async_set('light.test', 'on')
state1 = self.hass.states.get('light.test') state1 = hass.states.get('light.test')
with patch('homeassistant.core.dt_util.utcnow', return_value=point2): with patch('homeassistant.core.dt_util.utcnow', return_value=point2):
self.hass.states.set('light.test2', 'on') hass.states.async_set('light.test2', 'on')
state2 = self.hass.states.get('light.test2') state2 = hass.states.get('light.test2')
with patch('homeassistant.core.dt_util.utcnow', return_value=point3): with patch('homeassistant.core.dt_util.utcnow', return_value=point3):
self.hass.states.set('light.test3', 'on') hass.states.async_set('light.test3', 'on')
state3 = self.hass.states.get('light.test3') state3 = hass.states.get('light.test3')
assert [state2, state3] == \ assert [state2, state3] == \
state.get_changed_since([state1, state2, state3], point2) state.get_changed_since([state1, state2, state3], point2)
def test_reproduce_with_no_entity(self):
async def test_reproduce_with_no_entity(hass):
"""Test reproduce_state with no entity.""" """Test reproduce_state with no entity."""
calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) calls = async_mock_service(hass, 'light', SERVICE_TURN_ON)
state.reproduce_state(self.hass, ha.State('light.test', 'on')) await state.async_reproduce_state(hass, ha.State('light.test', 'on'))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 0 assert len(calls) == 0
assert self.hass.states.get('light.test') is None assert hass.states.get('light.test') is None
def test_reproduce_turn_on(self):
async def test_reproduce_turn_on(hass):
"""Test reproduce_state with SERVICE_TURN_ON.""" """Test reproduce_state with SERVICE_TURN_ON."""
calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) calls = async_mock_service(hass, 'light', SERVICE_TURN_ON)
self.hass.states.set('light.test', 'off') hass.states.async_set('light.test', 'off')
state.reproduce_state(self.hass, ha.State('light.test', 'on')) await state.async_reproduce_state(hass, ha.State('light.test', 'on'))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) > 0 assert len(calls) > 0
last_call = calls[-1] last_call = calls[-1]
assert 'light' == last_call.domain assert last_call.domain == 'light'
assert SERVICE_TURN_ON == last_call.service assert SERVICE_TURN_ON == last_call.service
assert ['light.test'] == last_call.data.get('entity_id') assert ['light.test'] == last_call.data.get('entity_id')
def test_reproduce_turn_off(self):
async def test_reproduce_turn_off(hass):
"""Test reproduce_state with SERVICE_TURN_OFF.""" """Test reproduce_state with SERVICE_TURN_OFF."""
calls = mock_service(self.hass, 'light', SERVICE_TURN_OFF) calls = async_mock_service(hass, 'light', SERVICE_TURN_OFF)
self.hass.states.set('light.test', 'on') hass.states.async_set('light.test', 'on')
state.reproduce_state(self.hass, ha.State('light.test', 'off')) await state.async_reproduce_state(hass, ha.State('light.test', 'off'))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) > 0 assert len(calls) > 0
last_call = calls[-1] last_call = calls[-1]
assert 'light' == last_call.domain assert last_call.domain == 'light'
assert SERVICE_TURN_OFF == last_call.service assert SERVICE_TURN_OFF == last_call.service
assert ['light.test'] == last_call.data.get('entity_id') assert ['light.test'] == last_call.data.get('entity_id')
def test_reproduce_complex_data(self):
"""Test reproduce_state with complex service data."""
calls = mock_service(self.hass, 'light', SERVICE_TURN_ON)
self.hass.states.set('light.test', 'off') async def test_reproduce_complex_data(hass):
"""Test reproduce_state with complex service data."""
calls = async_mock_service(hass, 'light', SERVICE_TURN_ON)
hass.states.async_set('light.test', 'off')
complex_data = ['hello', {'11': '22'}] complex_data = ['hello', {'11': '22'}]
state.reproduce_state(self.hass, ha.State('light.test', 'on', { await state.async_reproduce_state(hass, ha.State('light.test', 'on', {
'complex': complex_data 'complex': complex_data
})) }))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) > 0 assert len(calls) > 0
last_call = calls[-1] last_call = calls[-1]
assert 'light' == last_call.domain assert last_call.domain == 'light'
assert SERVICE_TURN_ON == last_call.service assert SERVICE_TURN_ON == last_call.service
assert complex_data == last_call.data.get('complex') assert complex_data == last_call.data.get('complex')
def test_reproduce_bad_state(self):
async def test_reproduce_bad_state(hass):
"""Test reproduce_state with bad state.""" """Test reproduce_state with bad state."""
calls = mock_service(self.hass, 'light', SERVICE_TURN_ON) calls = async_mock_service(hass, 'light', SERVICE_TURN_ON)
self.hass.states.set('light.test', 'off') hass.states.async_set('light.test', 'off')
state.reproduce_state(self.hass, ha.State('light.test', 'bad')) await state.async_reproduce_state(hass, ha.State('light.test', 'bad'))
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 0 assert len(calls) == 0
assert 'off' == self.hass.states.get('light.test').state assert hass.states.get('light.test').state == 'off'
def test_as_number_states(self):
async def test_as_number_states(hass):
"""Test state_as_number with states.""" """Test state_as_number with states."""
zero_states = (STATE_OFF, STATE_CLOSED, STATE_UNLOCKED, zero_states = (STATE_OFF, STATE_CLOSED, STATE_UNLOCKED,
STATE_BELOW_HORIZON, STATE_NOT_HOME) STATE_BELOW_HORIZON, STATE_NOT_HOME)
one_states = (STATE_ON, STATE_OPEN, STATE_LOCKED, STATE_ABOVE_HORIZON, one_states = (STATE_ON, STATE_OPEN, STATE_LOCKED, STATE_ABOVE_HORIZON,
STATE_HOME) STATE_HOME)
for _state in zero_states: for _state in zero_states:
assert 0 == state.state_as_number( assert state.state_as_number(ha.State('domain.test', _state, {})) == 0
ha.State('domain.test', _state, {}))
for _state in one_states: for _state in one_states:
assert 1 == state.state_as_number( assert state.state_as_number(ha.State('domain.test', _state, {})) == 1
ha.State('domain.test', _state, {}))
def test_as_number_coercion(self):
async def test_as_number_coercion(hass):
"""Test state_as_number with number.""" """Test state_as_number with number."""
for _state in ('0', '0.0', 0, 0.0): for _state in ('0', '0.0', 0, 0.0):
assert 0.0 == state.state_as_number( assert state.state_as_number(
ha.State('domain.test', _state, {})) ha.State('domain.test', _state, {})) == 0.0
for _state in ('1', '1.0', 1, 1.0): for _state in ('1', '1.0', 1, 1.0):
assert 1.0 == state.state_as_number( assert state.state_as_number(
ha.State('domain.test', _state, {})) ha.State('domain.test', _state, {})) == 1.0
def test_as_number_invalid_cases(self):
async def test_as_number_invalid_cases(hass):
"""Test state_as_number with invalid cases.""" """Test state_as_number with invalid cases."""
for _state in ('', 'foo', 'foo.bar', None, False, True, object, for _state in ('', 'foo', 'foo.bar', None, False, True, object,
object()): object()):

View file

@ -1,6 +1,5 @@
"""The tests for the Sun helpers.""" """The tests for the Sun helpers."""
# pylint: disable=protected-access # pylint: disable=protected-access
import unittest
from unittest.mock import patch from unittest.mock import patch
from datetime import timedelta, datetime from datetime import timedelta, datetime
@ -8,22 +7,8 @@ from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
import homeassistant.helpers.sun as sun import homeassistant.helpers.sun as sun
from tests.common import get_test_home_assistant
def test_next_events(hass):
# pylint: disable=invalid-name
class TestSun(unittest.TestCase):
"""Test the sun helpers."""
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
def tearDown(self):
"""Stop everything that was started."""
self.hass.stop()
def test_next_events(self):
"""Test retrieving next sun events.""" """Test retrieving next sun events."""
utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC)
from astral import Astral from astral import Astral
@ -31,8 +16,8 @@ class TestSun(unittest.TestCase):
astral = Astral() astral = Astral()
utc_today = utc_now.date() utc_today = utc_now.date()
latitude = self.hass.config.latitude latitude = hass.config.latitude
longitude = self.hass.config.longitude longitude = hass.config.longitude
mod = -1 mod = -1
while True: while True:
@ -85,19 +70,20 @@ class TestSun(unittest.TestCase):
with patch('homeassistant.helpers.condition.dt_util.utcnow', with patch('homeassistant.helpers.condition.dt_util.utcnow',
return_value=utc_now): return_value=utc_now):
assert next_dawn == sun.get_astral_event_next( assert next_dawn == sun.get_astral_event_next(
self.hass, 'dawn') hass, 'dawn')
assert next_dusk == sun.get_astral_event_next( assert next_dusk == sun.get_astral_event_next(
self.hass, 'dusk') hass, 'dusk')
assert next_midnight == sun.get_astral_event_next( assert next_midnight == sun.get_astral_event_next(
self.hass, 'solar_midnight') hass, 'solar_midnight')
assert next_noon == sun.get_astral_event_next( assert next_noon == sun.get_astral_event_next(
self.hass, 'solar_noon') hass, 'solar_noon')
assert next_rising == sun.get_astral_event_next( assert next_rising == sun.get_astral_event_next(
self.hass, SUN_EVENT_SUNRISE) hass, SUN_EVENT_SUNRISE)
assert next_setting == sun.get_astral_event_next( assert next_setting == sun.get_astral_event_next(
self.hass, SUN_EVENT_SUNSET) hass, SUN_EVENT_SUNSET)
def test_date_events(self):
def test_date_events(hass):
"""Test retrieving next sun events.""" """Test retrieving next sun events."""
utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC)
from astral import Astral from astral import Astral
@ -105,8 +91,8 @@ class TestSun(unittest.TestCase):
astral = Astral() astral = Astral()
utc_today = utc_now.date() utc_today = utc_now.date()
latitude = self.hass.config.latitude latitude = hass.config.latitude
longitude = self.hass.config.longitude longitude = hass.config.longitude
dawn = astral.dawn_utc(utc_today, latitude, longitude) dawn = astral.dawn_utc(utc_today, latitude, longitude)
dusk = astral.dusk_utc(utc_today, latitude, longitude) dusk = astral.dusk_utc(utc_today, latitude, longitude)
@ -116,19 +102,20 @@ class TestSun(unittest.TestCase):
sunset = astral.sunset_utc(utc_today, latitude, longitude) sunset = astral.sunset_utc(utc_today, latitude, longitude)
assert dawn == sun.get_astral_event_date( assert dawn == sun.get_astral_event_date(
self.hass, 'dawn', utc_today) hass, 'dawn', utc_today)
assert dusk == sun.get_astral_event_date( assert dusk == sun.get_astral_event_date(
self.hass, 'dusk', utc_today) hass, 'dusk', utc_today)
assert midnight == sun.get_astral_event_date( assert midnight == sun.get_astral_event_date(
self.hass, 'solar_midnight', utc_today) hass, 'solar_midnight', utc_today)
assert noon == sun.get_astral_event_date( assert noon == sun.get_astral_event_date(
self.hass, 'solar_noon', utc_today) hass, 'solar_noon', utc_today)
assert sunrise == sun.get_astral_event_date( assert sunrise == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNRISE, utc_today) hass, SUN_EVENT_SUNRISE, utc_today)
assert sunset == sun.get_astral_event_date( assert sunset == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNSET, utc_today) hass, SUN_EVENT_SUNSET, utc_today)
def test_date_events_default_date(self):
def test_date_events_default_date(hass):
"""Test retrieving next sun events.""" """Test retrieving next sun events."""
utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC)
from astral import Astral from astral import Astral
@ -136,8 +123,8 @@ class TestSun(unittest.TestCase):
astral = Astral() astral = Astral()
utc_today = utc_now.date() utc_today = utc_now.date()
latitude = self.hass.config.latitude latitude = hass.config.latitude
longitude = self.hass.config.longitude longitude = hass.config.longitude
dawn = astral.dawn_utc(utc_today, latitude, longitude) dawn = astral.dawn_utc(utc_today, latitude, longitude)
dusk = astral.dusk_utc(utc_today, latitude, longitude) dusk = astral.dusk_utc(utc_today, latitude, longitude)
@ -148,19 +135,20 @@ class TestSun(unittest.TestCase):
with patch('homeassistant.util.dt.now', return_value=utc_now): with patch('homeassistant.util.dt.now', return_value=utc_now):
assert dawn == sun.get_astral_event_date( assert dawn == sun.get_astral_event_date(
self.hass, 'dawn', utc_today) hass, 'dawn', utc_today)
assert dusk == sun.get_astral_event_date( assert dusk == sun.get_astral_event_date(
self.hass, 'dusk', utc_today) hass, 'dusk', utc_today)
assert midnight == sun.get_astral_event_date( assert midnight == sun.get_astral_event_date(
self.hass, 'solar_midnight', utc_today) hass, 'solar_midnight', utc_today)
assert noon == sun.get_astral_event_date( assert noon == sun.get_astral_event_date(
self.hass, 'solar_noon', utc_today) hass, 'solar_noon', utc_today)
assert sunrise == sun.get_astral_event_date( assert sunrise == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNRISE, utc_today) hass, SUN_EVENT_SUNRISE, utc_today)
assert sunset == sun.get_astral_event_date( assert sunset == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNSET, utc_today) hass, SUN_EVENT_SUNSET, utc_today)
def test_date_events_accepts_datetime(self):
def test_date_events_accepts_datetime(hass):
"""Test retrieving next sun events.""" """Test retrieving next sun events."""
utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 8, 0, 0, tzinfo=dt_util.UTC)
from astral import Astral from astral import Astral
@ -168,8 +156,8 @@ class TestSun(unittest.TestCase):
astral = Astral() astral = Astral()
utc_today = utc_now.date() utc_today = utc_now.date()
latitude = self.hass.config.latitude latitude = hass.config.latitude
longitude = self.hass.config.longitude longitude = hass.config.longitude
dawn = astral.dawn_utc(utc_today, latitude, longitude) dawn = astral.dawn_utc(utc_today, latitude, longitude)
dusk = astral.dusk_utc(utc_today, latitude, longitude) dusk = astral.dusk_utc(utc_today, latitude, longitude)
@ -179,52 +167,54 @@ class TestSun(unittest.TestCase):
sunset = astral.sunset_utc(utc_today, latitude, longitude) sunset = astral.sunset_utc(utc_today, latitude, longitude)
assert dawn == sun.get_astral_event_date( assert dawn == sun.get_astral_event_date(
self.hass, 'dawn', utc_now) hass, 'dawn', utc_now)
assert dusk == sun.get_astral_event_date( assert dusk == sun.get_astral_event_date(
self.hass, 'dusk', utc_now) hass, 'dusk', utc_now)
assert midnight == sun.get_astral_event_date( assert midnight == sun.get_astral_event_date(
self.hass, 'solar_midnight', utc_now) hass, 'solar_midnight', utc_now)
assert noon == sun.get_astral_event_date( assert noon == sun.get_astral_event_date(
self.hass, 'solar_noon', utc_now) hass, 'solar_noon', utc_now)
assert sunrise == sun.get_astral_event_date( assert sunrise == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNRISE, utc_now) hass, SUN_EVENT_SUNRISE, utc_now)
assert sunset == sun.get_astral_event_date( assert sunset == sun.get_astral_event_date(
self.hass, SUN_EVENT_SUNSET, utc_now) hass, SUN_EVENT_SUNSET, utc_now)
def test_is_up(self):
def test_is_up(hass):
"""Test retrieving next sun events.""" """Test retrieving next sun events."""
utc_now = datetime(2016, 11, 1, 12, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 12, 0, 0, tzinfo=dt_util.UTC)
with patch('homeassistant.helpers.condition.dt_util.utcnow', with patch('homeassistant.helpers.condition.dt_util.utcnow',
return_value=utc_now): return_value=utc_now):
assert not sun.is_up(self.hass) assert not sun.is_up(hass)
utc_now = datetime(2016, 11, 1, 18, 0, 0, tzinfo=dt_util.UTC) utc_now = datetime(2016, 11, 1, 18, 0, 0, tzinfo=dt_util.UTC)
with patch('homeassistant.helpers.condition.dt_util.utcnow', with patch('homeassistant.helpers.condition.dt_util.utcnow',
return_value=utc_now): return_value=utc_now):
assert sun.is_up(self.hass) assert sun.is_up(hass)
def test_norway_in_june(self):
def test_norway_in_june(hass):
"""Test location in Norway where the sun doesn't set in summer.""" """Test location in Norway where the sun doesn't set in summer."""
self.hass.config.latitude = 69.6 hass.config.latitude = 69.6
self.hass.config.longitude = 18.8 hass.config.longitude = 18.8
june = datetime(2016, 6, 1, tzinfo=dt_util.UTC) june = datetime(2016, 6, 1, tzinfo=dt_util.UTC)
print(sun.get_astral_event_date(self.hass, SUN_EVENT_SUNRISE, print(sun.get_astral_event_date(hass, SUN_EVENT_SUNRISE,
datetime(2017, 7, 25))) datetime(2017, 7, 25)))
print(sun.get_astral_event_date(self.hass, SUN_EVENT_SUNSET, print(sun.get_astral_event_date(hass, SUN_EVENT_SUNSET,
datetime(2017, 7, 25))) datetime(2017, 7, 25)))
print(sun.get_astral_event_date(self.hass, SUN_EVENT_SUNRISE, print(sun.get_astral_event_date(hass, SUN_EVENT_SUNRISE,
datetime(2017, 7, 26))) datetime(2017, 7, 26)))
print(sun.get_astral_event_date(self.hass, SUN_EVENT_SUNSET, print(sun.get_astral_event_date(hass, SUN_EVENT_SUNSET,
datetime(2017, 7, 26))) datetime(2017, 7, 26)))
assert sun.get_astral_event_next(self.hass, SUN_EVENT_SUNRISE, june) \ assert sun.get_astral_event_next(hass, SUN_EVENT_SUNRISE, june) \
== datetime(2016, 7, 25, 23, 23, 39, tzinfo=dt_util.UTC) == datetime(2016, 7, 25, 23, 23, 39, tzinfo=dt_util.UTC)
assert sun.get_astral_event_next(self.hass, SUN_EVENT_SUNSET, june) \ assert sun.get_astral_event_next(hass, SUN_EVENT_SUNSET, june) \
== datetime(2016, 7, 26, 22, 19, 1, tzinfo=dt_util.UTC) == datetime(2016, 7, 26, 22, 19, 1, tzinfo=dt_util.UTC)
assert sun.get_astral_event_date(self.hass, SUN_EVENT_SUNRISE, june) \ assert sun.get_astral_event_date(hass, SUN_EVENT_SUNRISE, june) \
is None is None
assert sun.get_astral_event_date(self.hass, SUN_EVENT_SUNSET, june) \ assert sun.get_astral_event_date(hass, SUN_EVENT_SUNSET, june) \
is None is None

View file

@ -1,50 +1,34 @@
"""Tests Home Assistant temperature helpers.""" """Tests Home Assistant temperature helpers."""
import unittest import pytest
from tests.common import get_test_home_assistant
from homeassistant.const import ( from homeassistant.const import (
TEMP_CELSIUS, PRECISION_WHOLE, TEMP_FAHRENHEIT, PRECISION_HALVES, TEMP_CELSIUS, PRECISION_WHOLE, TEMP_FAHRENHEIT, PRECISION_HALVES,
PRECISION_TENTHS) PRECISION_TENTHS)
from homeassistant.helpers.temperature import display_temp from homeassistant.helpers.temperature import display_temp
from homeassistant.util.unit_system import METRIC_SYSTEM
import pytest
TEMP = 24.636626 TEMP = 24.636626
class TestHelpersTemperature(unittest.TestCase): def test_temperature_not_a_number(hass):
"""Set up the temperature tests."""
def setUp(self):
"""Set up the tests."""
self.hass = get_test_home_assistant()
self.hass.config.unit_system = METRIC_SYSTEM
def tearDown(self):
"""Stop down stuff we started."""
self.hass.stop()
def test_temperature_not_a_number(self):
"""Test that temperature is a number.""" """Test that temperature is a number."""
temp = "Temperature" temp = "Temperature"
with pytest.raises(Exception) as exception: with pytest.raises(Exception) as exception:
display_temp(self.hass, temp, TEMP_CELSIUS, PRECISION_HALVES) display_temp(hass, temp, TEMP_CELSIUS, PRECISION_HALVES)
assert "Temperature is not a number: {}".format(temp) \ assert "Temperature is not a number: {}".format(temp) \
in str(exception) in str(exception)
def test_celsius_halves(self):
def test_celsius_halves(hass):
"""Test temperature to celsius rounding to halves.""" """Test temperature to celsius rounding to halves."""
assert 24.5 == display_temp( assert display_temp(hass, TEMP, TEMP_CELSIUS, PRECISION_HALVES) == 24.5
self.hass, TEMP, TEMP_CELSIUS, PRECISION_HALVES)
def test_celsius_tenths(self):
def test_celsius_tenths(hass):
"""Test temperature to celsius rounding to tenths.""" """Test temperature to celsius rounding to tenths."""
assert 24.6 == display_temp( assert display_temp(hass, TEMP, TEMP_CELSIUS, PRECISION_TENTHS) == 24.6
self.hass, TEMP, TEMP_CELSIUS, PRECISION_TENTHS)
def test_fahrenheit_wholes(self):
def test_fahrenheit_wholes(hass):
"""Test temperature to fahrenheit rounding to wholes.""" """Test temperature to fahrenheit rounding to wholes."""
assert -4 == display_temp( assert display_temp(hass, TEMP, TEMP_FAHRENHEIT, PRECISION_WHOLE) == -4
self.hass, TEMP, TEMP_FAHRENHEIT, PRECISION_WHOLE)

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,6 @@
"""Test check_config script.""" """Test check_config script."""
import asyncio
import logging import logging
import os # noqa: F401 pylint: disable=unused-import import os # noqa: F401 pylint: disable=unused-import
import unittest
from unittest.mock import patch from unittest.mock import patch
import homeassistant.scripts.check_config as check_config import homeassistant.scripts.check_config as check_config
@ -36,25 +34,9 @@ def normalize_yaml_files(check_dict):
for key in sorted(check_dict['yaml_files'].keys())] for key in sorted(check_dict['yaml_files'].keys())]
# pylint: disable=unsubscriptable-object
class TestCheckConfig(unittest.TestCase):
"""Tests for the homeassistant.scripts.check_config module."""
def setUp(self):
"""Prepare the test."""
# Somewhere in the tests our event loop gets killed,
# this ensures we have one.
try:
asyncio.get_event_loop()
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())
# Will allow seeing full diff
self.maxDiff = None # pylint: disable=invalid-name
# pylint: disable=no-self-use,invalid-name # pylint: disable=no-self-use,invalid-name
@patch('os.path.isfile', return_value=True) @patch('os.path.isfile', return_value=True)
def test_bad_core_config(self, isfile_patch): def test_bad_core_config(isfile_patch, loop):
"""Test a bad core config setup.""" """Test a bad core config setup."""
files = { files = {
YAML_CONFIG_FILE: BAD_CORE_CONFIG, YAML_CONFIG_FILE: BAD_CORE_CONFIG,
@ -64,8 +46,9 @@ class TestCheckConfig(unittest.TestCase):
assert res['except'].keys() == {'homeassistant'} assert res['except'].keys() == {'homeassistant'}
assert res['except']['homeassistant'][1] == {'unit_system': 'bad'} assert res['except']['homeassistant'][1] == {'unit_system': 'bad'}
@patch('os.path.isfile', return_value=True) @patch('os.path.isfile', return_value=True)
def test_config_platform_valid(self, isfile_patch): def test_config_platform_valid(isfile_patch, loop):
"""Test a valid platform setup.""" """Test a valid platform setup."""
files = { files = {
YAML_CONFIG_FILE: BASE_CONFIG + 'light:\n platform: demo', YAML_CONFIG_FILE: BASE_CONFIG + 'light:\n platform: demo',
@ -79,8 +62,9 @@ class TestCheckConfig(unittest.TestCase):
assert res['secrets'] == {} assert res['secrets'] == {}
assert len(res['yaml_files']) == 1 assert len(res['yaml_files']) == 1
@patch('os.path.isfile', return_value=True) @patch('os.path.isfile', return_value=True)
def test_component_platform_not_found(self, isfile_patch): def test_component_platform_not_found(isfile_patch, loop):
"""Test errors if component or platform not found.""" """Test errors if component or platform not found."""
# Make sure they don't exist # Make sure they don't exist
files = { files = {
@ -111,8 +95,9 @@ class TestCheckConfig(unittest.TestCase):
assert res['secrets'] == {} assert res['secrets'] == {}
assert len(res['yaml_files']) == 1 assert len(res['yaml_files']) == 1
@patch('os.path.isfile', return_value=True) @patch('os.path.isfile', return_value=True)
def test_secrets(self, isfile_patch): def test_secrets(isfile_patch, loop):
"""Test secrets config checking method.""" """Test secrets config checking method."""
secrets_path = get_test_config_dir('secrets.yaml') secrets_path = get_test_config_dir('secrets.yaml')
@ -146,8 +131,9 @@ class TestCheckConfig(unittest.TestCase):
assert normalize_yaml_files(res) == [ assert normalize_yaml_files(res) == [
'.../configuration.yaml', '.../secrets.yaml'] '.../configuration.yaml', '.../secrets.yaml']
@patch('os.path.isfile', return_value=True) @patch('os.path.isfile', return_value=True)
def test_package_invalid(self, isfile_patch): def test_package_invalid(isfile_patch, loop):
"""Test a valid platform setup.""" """Test a valid platform setup."""
files = { files = {
YAML_CONFIG_FILE: BASE_CONFIG + ( YAML_CONFIG_FILE: BASE_CONFIG + (
@ -168,7 +154,8 @@ class TestCheckConfig(unittest.TestCase):
assert res['secrets'] == {} assert res['secrets'] == {}
assert len(res['yaml_files']) == 1 assert len(res['yaml_files']) == 1
def test_bootstrap_error(self):
def test_bootstrap_error(loop):
"""Test a valid platform setup.""" """Test a valid platform setup."""
files = { files = {
YAML_CONFIG_FILE: BASE_CONFIG + 'automation: !include no.yaml', YAML_CONFIG_FILE: BASE_CONFIG + 'automation: !include no.yaml',

View file

@ -1,19 +1,15 @@
"""Test script init.""" """Test script init."""
import unittest
from unittest.mock import patch from unittest.mock import patch
import homeassistant.scripts as scripts import homeassistant.scripts as scripts
class TestScripts(unittest.TestCase):
"""Tests homeassistant.scripts module."""
@patch('homeassistant.scripts.get_default_config_dir', @patch('homeassistant.scripts.get_default_config_dir',
return_value='/default') return_value='/default')
def test_config_per_platform(self, mock_def): def test_config_per_platform(mock_def):
"""Test config per platform method.""" """Test config per platform method."""
self.assertEqual(scripts.get_default_config_dir(), '/default') assert scripts.get_default_config_dir() == '/default'
self.assertEqual(scripts.extract_config_dir(), '/default') assert scripts.extract_config_dir() == '/default'
self.assertEqual(scripts.extract_config_dir(['']), '/default') assert scripts.extract_config_dir(['']) == '/default'
self.assertEqual(scripts.extract_config_dir(['-c', '/arg']), '/arg') assert scripts.extract_config_dir(['-c', '/arg']) == '/arg'
self.assertEqual(scripts.extract_config_dir(['--config', '/a']), '/a') assert scripts.extract_config_dir(['--config', '/a']) == '/a'

View file

@ -2,7 +2,6 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio import asyncio
import os import os
import unittest
import unittest.mock as mock import unittest.mock as mock
from collections import OrderedDict from collections import OrderedDict
from ipaddress import ip_network from ipaddress import ip_network
@ -23,7 +22,6 @@ from homeassistant.const import (
CONF_AUTH_PROVIDERS, CONF_AUTH_MFA_MODULES) CONF_AUTH_PROVIDERS, CONF_AUTH_MFA_MODULES)
from homeassistant.util import location as location_util, dt as dt_util from homeassistant.util import location as location_util, dt as dt_util
from homeassistant.util.yaml import SECRET_YAML from homeassistant.util.yaml import SECRET_YAML
from homeassistant.util.async_ import run_coroutine_threadsafe
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.components.config.group import ( from homeassistant.components.config.group import (
CONFIG_PATH as GROUP_CONFIG_PATH) CONFIG_PATH as GROUP_CONFIG_PATH)
@ -36,7 +34,7 @@ from homeassistant.components.config.customize import (
import homeassistant.scripts.check_config as check_config import homeassistant.scripts.check_config as check_config
from tests.common import ( from tests.common import (
get_test_config_dir, get_test_home_assistant, patch_yaml_files) get_test_config_dir, patch_yaml_files)
CONFIG_DIR = get_test_config_dir() CONFIG_DIR = get_test_config_dir()
YAML_PATH = os.path.join(CONFIG_DIR, config_util.YAML_CONFIG_FILE) YAML_PATH = os.path.join(CONFIG_DIR, config_util.YAML_CONFIG_FILE)
@ -55,16 +53,7 @@ def create_file(path):
pass pass
class TestConfig(unittest.TestCase): def teardown():
"""Test the configutils."""
# pylint: disable=invalid-name
def setUp(self):
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
# pylint: disable=invalid-name
def tearDown(self):
"""Clean up.""" """Clean up."""
dt_util.DEFAULT_TIME_ZONE = ORIG_TIMEZONE dt_util.DEFAULT_TIME_ZONE = ORIG_TIMEZONE
@ -89,10 +78,8 @@ class TestConfig(unittest.TestCase):
if os.path.isfile(CUSTOMIZE_PATH): if os.path.isfile(CUSTOMIZE_PATH):
os.remove(CUSTOMIZE_PATH) os.remove(CUSTOMIZE_PATH)
self.hass.stop()
# pylint: disable=no-self-use def test_create_default_config():
def test_create_default_config(self):
"""Test creation of default config.""" """Test creation of default config."""
config_util.create_default_config(CONFIG_DIR, False) config_util.create_default_config(CONFIG_DIR, False)
@ -103,14 +90,16 @@ class TestConfig(unittest.TestCase):
assert os.path.isfile(AUTOMATIONS_PATH) assert os.path.isfile(AUTOMATIONS_PATH)
assert os.path.isfile(CUSTOMIZE_PATH) assert os.path.isfile(CUSTOMIZE_PATH)
def test_find_config_file_yaml(self):
def test_find_config_file_yaml():
"""Test if it finds a YAML config file.""" """Test if it finds a YAML config file."""
create_file(YAML_PATH) create_file(YAML_PATH)
assert YAML_PATH == config_util.find_config_file(CONFIG_DIR) assert YAML_PATH == config_util.find_config_file(CONFIG_DIR)
@mock.patch('builtins.print') @mock.patch('builtins.print')
def test_ensure_config_exists_creates_config(self, mock_print): def test_ensure_config_exists_creates_config(mock_print):
"""Test that calling ensure_config_exists. """Test that calling ensure_config_exists.
If not creates a new config file. If not creates a new config file.
@ -120,7 +109,8 @@ class TestConfig(unittest.TestCase):
assert os.path.isfile(YAML_PATH) assert os.path.isfile(YAML_PATH)
assert mock_print.called assert mock_print.called
def test_ensure_config_exists_uses_existing_config(self):
def test_ensure_config_exists_uses_existing_config():
"""Test that calling ensure_config_exists uses existing config.""" """Test that calling ensure_config_exists uses existing config."""
create_file(YAML_PATH) create_file(YAML_PATH)
config_util.ensure_config_exists(CONFIG_DIR, False) config_util.ensure_config_exists(CONFIG_DIR, False)
@ -129,15 +119,17 @@ class TestConfig(unittest.TestCase):
content = f.read() content = f.read()
# File created with create_file are empty # File created with create_file are empty
assert '' == content assert content == ''
def test_load_yaml_config_converts_empty_files_to_dict(self):
def test_load_yaml_config_converts_empty_files_to_dict():
"""Test that loading an empty file returns an empty dict.""" """Test that loading an empty file returns an empty dict."""
create_file(YAML_PATH) create_file(YAML_PATH)
assert isinstance(config_util.load_yaml_config_file(YAML_PATH), dict) assert isinstance(config_util.load_yaml_config_file(YAML_PATH), dict)
def test_load_yaml_config_raises_error_if_not_dict(self):
def test_load_yaml_config_raises_error_if_not_dict():
"""Test error raised when YAML file is not a dict.""" """Test error raised when YAML file is not a dict."""
with open(YAML_PATH, 'w') as f: with open(YAML_PATH, 'w') as f:
f.write('5') f.write('5')
@ -145,7 +137,8 @@ class TestConfig(unittest.TestCase):
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
config_util.load_yaml_config_file(YAML_PATH) config_util.load_yaml_config_file(YAML_PATH)
def test_load_yaml_config_raises_error_if_malformed_yaml(self):
def test_load_yaml_config_raises_error_if_malformed_yaml():
"""Test error raised if invalid YAML.""" """Test error raised if invalid YAML."""
with open(YAML_PATH, 'w') as f: with open(YAML_PATH, 'w') as f:
f.write(':') f.write(':')
@ -153,7 +146,8 @@ class TestConfig(unittest.TestCase):
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
config_util.load_yaml_config_file(YAML_PATH) config_util.load_yaml_config_file(YAML_PATH)
def test_load_yaml_config_raises_error_if_unsafe_yaml(self):
def test_load_yaml_config_raises_error_if_unsafe_yaml():
"""Test error raised if unsafe YAML.""" """Test error raised if unsafe YAML."""
with open(YAML_PATH, 'w') as f: with open(YAML_PATH, 'w') as f:
f.write('hello: !!python/object/apply:os.system') f.write('hello: !!python/object/apply:os.system')
@ -161,7 +155,8 @@ class TestConfig(unittest.TestCase):
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
config_util.load_yaml_config_file(YAML_PATH) config_util.load_yaml_config_file(YAML_PATH)
def test_load_yaml_config_preserves_key_order(self):
def test_load_yaml_config_preserves_key_order():
"""Test removal of library.""" """Test removal of library."""
with open(YAML_PATH, 'w') as f: with open(YAML_PATH, 'w') as f:
f.write('hello: 2\n') f.write('hello: 2\n')
@ -170,6 +165,7 @@ class TestConfig(unittest.TestCase):
assert [('hello', 2), ('world', 1)] == \ assert [('hello', 2), ('world', 1)] == \
list(config_util.load_yaml_config_file(YAML_PATH).items()) list(config_util.load_yaml_config_file(YAML_PATH).items())
@mock.patch('homeassistant.util.location.detect_location_info', @mock.patch('homeassistant.util.location.detect_location_info',
return_value=location_util.LocationInfo( return_value=location_util.LocationInfo(
'0.0.0.0', 'US', 'United States', 'CA', 'California', '0.0.0.0', 'US', 'United States', 'CA', 'California',
@ -177,7 +173,7 @@ class TestConfig(unittest.TestCase):
-117.2073, True)) -117.2073, True))
@mock.patch('homeassistant.util.location.elevation', return_value=101) @mock.patch('homeassistant.util.location.elevation', return_value=101)
@mock.patch('builtins.print') @mock.patch('builtins.print')
def test_create_default_config_detect_location(self, mock_detect, def test_create_default_config_detect_location(mock_detect,
mock_elev, mock_print): mock_elev, mock_print):
"""Test that detect location sets the correct config keys.""" """Test that detect location sets the correct config keys."""
config_util.ensure_config_exists(CONFIG_DIR) config_util.ensure_config_exists(CONFIG_DIR)
@ -201,9 +197,9 @@ class TestConfig(unittest.TestCase):
assert expected_values == ha_conf assert expected_values == ha_conf
assert mock_print.called assert mock_print.called
@mock.patch('builtins.print') @mock.patch('builtins.print')
def test_create_default_config_returns_none_if_write_error(self, def test_create_default_config_returns_none_if_write_error(mock_print):
mock_print):
"""Test the writing of a default configuration. """Test the writing of a default configuration.
Non existing folder returns None. Non existing folder returns None.
@ -212,8 +208,8 @@ class TestConfig(unittest.TestCase):
os.path.join(CONFIG_DIR, 'non_existing_dir/'), False) is None os.path.join(CONFIG_DIR, 'non_existing_dir/'), False) is None
assert mock_print.called assert mock_print.called
# pylint: disable=no-self-use
def test_core_config_schema(self): def test_core_config_schema():
"""Test core config schema.""" """Test core config schema."""
for value in ( for value in (
{CONF_UNIT_SYSTEM: 'K'}, {CONF_UNIT_SYSTEM: 'K'},
@ -239,7 +235,8 @@ class TestConfig(unittest.TestCase):
}, },
}) })
def test_customize_dict_schema(self):
def test_customize_dict_schema():
"""Test basic customize config validation.""" """Test basic customize config validation."""
values = ( values = (
{ATTR_FRIENDLY_NAME: None}, {ATTR_FRIENDLY_NAME: None},
@ -262,40 +259,42 @@ class TestConfig(unittest.TestCase):
ATTR_ASSUMED_STATE: False ATTR_ASSUMED_STATE: False
} }
def test_customize_glob_is_ordered(self):
def test_customize_glob_is_ordered():
"""Test that customize_glob preserves order.""" """Test that customize_glob preserves order."""
conf = config_util.CORE_CONFIG_SCHEMA( conf = config_util.CORE_CONFIG_SCHEMA(
{'customize_glob': OrderedDict()}) {'customize_glob': OrderedDict()})
assert isinstance(conf['customize_glob'], OrderedDict) assert isinstance(conf['customize_glob'], OrderedDict)
def _compute_state(self, config):
run_coroutine_threadsafe( async def _compute_state(hass, config):
config_util.async_process_ha_core_config(self.hass, config), await config_util.async_process_ha_core_config(hass, config)
self.hass.loop).result()
entity = Entity() entity = Entity()
entity.entity_id = 'test.test' entity.entity_id = 'test.test'
entity.hass = self.hass entity.hass = hass
entity.schedule_update_ha_state() entity.schedule_update_ha_state()
self.hass.block_till_done() await hass.async_block_till_done()
return self.hass.states.get('test.test') return hass.states.get('test.test')
def test_entity_customization(self):
async def test_entity_customization(hass):
"""Test entity customization through configuration.""" """Test entity customization through configuration."""
config = {CONF_LATITUDE: 50, config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50, CONF_LONGITUDE: 50,
CONF_NAME: 'Test', CONF_NAME: 'Test',
CONF_CUSTOMIZE: {'test.test': {'hidden': True}}} CONF_CUSTOMIZE: {'test.test': {'hidden': True}}}
state = self._compute_state(config) state = await _compute_state(hass, config)
assert state.attributes['hidden'] assert state.attributes['hidden']
@mock.patch('homeassistant.config.shutil') @mock.patch('homeassistant.config.shutil')
@mock.patch('homeassistant.config.os') @mock.patch('homeassistant.config.os')
def test_remove_lib_on_upgrade(self, mock_os, mock_shutil): def test_remove_lib_on_upgrade(mock_os, mock_shutil, hass):
"""Test removal of library on upgrade from before 0.50.""" """Test removal of library on upgrade from before 0.50."""
ha_version = '0.49.0' ha_version = '0.49.0'
mock_os.path.isdir = mock.Mock(return_value=True) mock_os.path.isdir = mock.Mock(return_value=True)
@ -304,16 +303,17 @@ class TestConfig(unittest.TestCase):
opened_file = mock_open.return_value opened_file = mock_open.return_value
# pylint: disable=no-member # pylint: disable=no-member
opened_file.readline.return_value = ha_version opened_file.readline.return_value = ha_version
self.hass.config.path = mock.Mock() hass.config.path = mock.Mock()
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
hass_path = self.hass.config.path.return_value hass_path = hass.config.path.return_value
assert mock_os.path.isdir.call_count == 1 assert mock_os.path.isdir.call_count == 1
assert mock_os.path.isdir.call_args == mock.call(hass_path) assert mock_os.path.isdir.call_args == mock.call(hass_path)
assert mock_shutil.rmtree.call_count == 1 assert mock_shutil.rmtree.call_count == 1
assert mock_shutil.rmtree.call_args == mock.call(hass_path) assert mock_shutil.rmtree.call_args == mock.call(hass_path)
def test_process_config_upgrade(self):
def test_process_config_upgrade(hass):
"""Test update of version on upgrade.""" """Test update of version on upgrade."""
ha_version = '0.92.0' ha_version = '0.92.0'
@ -324,12 +324,13 @@ class TestConfig(unittest.TestCase):
# pylint: disable=no-member # pylint: disable=no-member
opened_file.readline.return_value = ha_version opened_file.readline.return_value = ha_version
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
assert opened_file.write.call_count == 1 assert opened_file.write.call_count == 1
assert opened_file.write.call_args == mock.call('0.91.0') assert opened_file.write.call_args == mock.call('0.91.0')
def test_config_upgrade_same_version(self):
def test_config_upgrade_same_version(hass):
"""Test no update of version on no upgrade.""" """Test no update of version on no upgrade."""
ha_version = __version__ ha_version = __version__
@ -339,12 +340,13 @@ class TestConfig(unittest.TestCase):
# pylint: disable=no-member # pylint: disable=no-member
opened_file.readline.return_value = ha_version opened_file.readline.return_value = ha_version
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
assert opened_file.write.call_count == 0 assert opened_file.write.call_count == 0
@mock.patch('homeassistant.config.find_config_file', mock.Mock()) @mock.patch('homeassistant.config.find_config_file', mock.Mock())
def test_config_upgrade_no_file(self): def test_config_upgrade_no_file(hass):
"""Test update of version on upgrade, with no version file.""" """Test update of version on upgrade, with no version file."""
mock_open = mock.mock_open() mock_open = mock.mock_open()
mock_open.side_effect = [FileNotFoundError(), mock_open.side_effect = [FileNotFoundError(),
@ -353,14 +355,15 @@ class TestConfig(unittest.TestCase):
with mock.patch('homeassistant.config.open', mock_open, create=True): with mock.patch('homeassistant.config.open', mock_open, create=True):
opened_file = mock_open.return_value opened_file = mock_open.return_value
# pylint: disable=no-member # pylint: disable=no-member
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
assert opened_file.write.call_count == 1 assert opened_file.write.call_count == 1
assert opened_file.write.call_args == mock.call(__version__) assert opened_file.write.call_args == mock.call(__version__)
@mock.patch('homeassistant.config.shutil') @mock.patch('homeassistant.config.shutil')
@mock.patch('homeassistant.config.os') @mock.patch('homeassistant.config.os')
@mock.patch('homeassistant.config.find_config_file', mock.Mock()) @mock.patch('homeassistant.config.find_config_file', mock.Mock())
def test_migrate_file_on_upgrade(self, mock_os, mock_shutil): def test_migrate_file_on_upgrade(mock_os, mock_shutil, hass):
"""Test migrate of config files on upgrade.""" """Test migrate of config files on upgrade."""
ha_version = '0.7.0' ha_version = '0.7.0'
@ -372,22 +375,22 @@ class TestConfig(unittest.TestCase):
return True return True
with mock.patch('homeassistant.config.open', mock_open, create=True), \ with mock.patch('homeassistant.config.open', mock_open, create=True), \
mock.patch( mock.patch('homeassistant.config.os.path.isfile', _mock_isfile):
'homeassistant.config.os.path.isfile', _mock_isfile):
opened_file = mock_open.return_value opened_file = mock_open.return_value
# pylint: disable=no-member # pylint: disable=no-member
opened_file.readline.return_value = ha_version opened_file.readline.return_value = ha_version
self.hass.config.path = mock.Mock() hass.config.path = mock.Mock()
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
assert mock_os.rename.call_count == 1 assert mock_os.rename.call_count == 1
@mock.patch('homeassistant.config.shutil') @mock.patch('homeassistant.config.shutil')
@mock.patch('homeassistant.config.os') @mock.patch('homeassistant.config.os')
@mock.patch('homeassistant.config.find_config_file', mock.Mock()) @mock.patch('homeassistant.config.find_config_file', mock.Mock())
def test_migrate_no_file_on_upgrade(self, mock_os, mock_shutil): def test_migrate_no_file_on_upgrade(mock_os, mock_shutil, hass):
"""Test not migrating config files on upgrade.""" """Test not migrating config files on upgrade."""
ha_version = '0.7.0' ha_version = '0.7.0'
@ -399,24 +402,23 @@ class TestConfig(unittest.TestCase):
return False return False
with mock.patch('homeassistant.config.open', mock_open, create=True), \ with mock.patch('homeassistant.config.open', mock_open, create=True), \
mock.patch( mock.patch('homeassistant.config.os.path.isfile', _mock_isfile):
'homeassistant.config.os.path.isfile', _mock_isfile):
opened_file = mock_open.return_value opened_file = mock_open.return_value
# pylint: disable=no-member # pylint: disable=no-member
opened_file.readline.return_value = ha_version opened_file.readline.return_value = ha_version
self.hass.config.path = mock.Mock() hass.config.path = mock.Mock()
config_util.process_ha_config_upgrade(self.hass) config_util.process_ha_config_upgrade(hass)
assert mock_os.rename.call_count == 0 assert mock_os.rename.call_count == 0
def test_loading_configuration(self):
"""Test loading core config onto hass object."""
self.hass.config = mock.Mock()
run_coroutine_threadsafe( async def test_loading_configuration(hass):
config_util.async_process_ha_core_config(self.hass, { """Test loading core config onto hass object."""
hass.config = mock.Mock()
await config_util.async_process_ha_core_config(hass, {
'latitude': 60, 'latitude': 60,
'longitude': 50, 'longitude': 50,
'elevation': 25, 'elevation': 25,
@ -424,44 +426,44 @@ class TestConfig(unittest.TestCase):
CONF_UNIT_SYSTEM: CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM: CONF_UNIT_SYSTEM_IMPERIAL,
'time_zone': 'America/New_York', 'time_zone': 'America/New_York',
'whitelist_external_dirs': '/tmp', 'whitelist_external_dirs': '/tmp',
}), self.hass.loop).result() })
assert self.hass.config.latitude == 60 assert hass.config.latitude == 60
assert self.hass.config.longitude == 50 assert hass.config.longitude == 50
assert self.hass.config.elevation == 25 assert hass.config.elevation == 25
assert self.hass.config.location_name == 'Huis' assert hass.config.location_name == 'Huis'
assert self.hass.config.units.name == CONF_UNIT_SYSTEM_IMPERIAL assert hass.config.units.name == CONF_UNIT_SYSTEM_IMPERIAL
assert self.hass.config.time_zone.zone == 'America/New_York' assert hass.config.time_zone.zone == 'America/New_York'
assert len(self.hass.config.whitelist_external_dirs) == 2 assert len(hass.config.whitelist_external_dirs) == 2
assert '/tmp' in self.hass.config.whitelist_external_dirs assert '/tmp' in hass.config.whitelist_external_dirs
def test_loading_configuration_temperature_unit(self):
async def test_loading_configuration_temperature_unit(hass):
"""Test backward compatibility when loading core config.""" """Test backward compatibility when loading core config."""
self.hass.config = mock.Mock() hass.config = mock.Mock()
run_coroutine_threadsafe( await config_util.async_process_ha_core_config(hass, {
config_util.async_process_ha_core_config(self.hass, {
'latitude': 60, 'latitude': 60,
'longitude': 50, 'longitude': 50,
'elevation': 25, 'elevation': 25,
'name': 'Huis', 'name': 'Huis',
CONF_TEMPERATURE_UNIT: 'C', CONF_TEMPERATURE_UNIT: 'C',
'time_zone': 'America/New_York', 'time_zone': 'America/New_York',
}), self.hass.loop).result() })
assert self.hass.config.latitude == 60 assert hass.config.latitude == 60
assert self.hass.config.longitude == 50 assert hass.config.longitude == 50
assert self.hass.config.elevation == 25 assert hass.config.elevation == 25
assert self.hass.config.location_name == 'Huis' assert hass.config.location_name == 'Huis'
assert self.hass.config.units.name == CONF_UNIT_SYSTEM_METRIC assert hass.config.units.name == CONF_UNIT_SYSTEM_METRIC
assert self.hass.config.time_zone.zone == 'America/New_York' assert hass.config.time_zone.zone == 'America/New_York'
def test_loading_configuration_from_packages(self):
async def test_loading_configuration_from_packages(hass):
"""Test loading packages config onto hass object config.""" """Test loading packages config onto hass object config."""
self.hass.config = mock.Mock() hass.config = mock.Mock()
run_coroutine_threadsafe( await config_util.async_process_ha_core_config(hass, {
config_util.async_process_ha_core_config(self.hass, {
'latitude': 39, 'latitude': 39,
'longitude': -1, 'longitude': -1,
'elevation': 500, 'elevation': 500,
@ -473,12 +475,11 @@ class TestConfig(unittest.TestCase):
'package_2': {'light': {'platform': 'hue'}, 'package_2': {'light': {'platform': 'hue'},
'media_extractor': None, 'media_extractor': None,
'sun': None}}, 'sun': None}},
}), self.hass.loop).result() })
# Empty packages not allowed # Empty packages not allowed
with pytest.raises(MultipleInvalid): with pytest.raises(MultipleInvalid):
run_coroutine_threadsafe( await config_util.async_process_ha_core_config(hass, {
config_util.async_process_ha_core_config(self.hass, {
'latitude': 39, 'latitude': 39,
'longitude': -1, 'longitude': -1,
'elevation': 500, 'elevation': 500,
@ -486,81 +487,74 @@ class TestConfig(unittest.TestCase):
CONF_TEMPERATURE_UNIT: 'C', CONF_TEMPERATURE_UNIT: 'C',
'time_zone': 'Europe/Madrid', 'time_zone': 'Europe/Madrid',
'packages': {'empty_package': None}, 'packages': {'empty_package': None},
}), self.hass.loop).result() })
@mock.patch('homeassistant.util.location.detect_location_info',
@asynctest.mock.patch('homeassistant.util.location.detect_location_info',
autospec=True, return_value=location_util.LocationInfo( autospec=True, return_value=location_util.LocationInfo(
'0.0.0.0', 'US', 'United States', 'CA', 'California', '0.0.0.0', 'US', 'United States', 'CA', 'California',
'San Diego', '92122', 'America/Los_Angeles', 32.8594, 'San Diego', '92122', 'America/Los_Angeles', 32.8594,
-117.2073, True)) -117.2073, True))
@mock.patch('homeassistant.util.location.elevation', @asynctest.mock.patch('homeassistant.util.location.elevation',
autospec=True, return_value=101) autospec=True, return_value=101)
def test_discovering_configuration(self, mock_detect, mock_elevation): async def test_discovering_configuration(mock_detect, mock_elevation, hass):
"""Test auto discovery for missing core configs.""" """Test auto discovery for missing core configs."""
self.hass.config.latitude = None hass.config.latitude = None
self.hass.config.longitude = None hass.config.longitude = None
self.hass.config.elevation = None hass.config.elevation = None
self.hass.config.location_name = None hass.config.location_name = None
self.hass.config.time_zone = None hass.config.time_zone = None
run_coroutine_threadsafe( await config_util.async_process_ha_core_config(hass, {})
config_util.async_process_ha_core_config(
self.hass, {}), self.hass.loop
).result()
assert self.hass.config.latitude == 32.8594 assert hass.config.latitude == 32.8594
assert self.hass.config.longitude == -117.2073 assert hass.config.longitude == -117.2073
assert self.hass.config.elevation == 101 assert hass.config.elevation == 101
assert self.hass.config.location_name == 'San Diego' assert hass.config.location_name == 'San Diego'
assert self.hass.config.units.name == CONF_UNIT_SYSTEM_METRIC assert hass.config.units.name == CONF_UNIT_SYSTEM_METRIC
assert self.hass.config.units.is_metric assert hass.config.units.is_metric
assert self.hass.config.time_zone.zone == 'America/Los_Angeles' assert hass.config.time_zone.zone == 'America/Los_Angeles'
@mock.patch('homeassistant.util.location.detect_location_info',
@asynctest.mock.patch('homeassistant.util.location.detect_location_info',
autospec=True, return_value=None) autospec=True, return_value=None)
@mock.patch('homeassistant.util.location.elevation', return_value=0) @asynctest.mock.patch('homeassistant.util.location.elevation', return_value=0)
def test_discovering_configuration_auto_detect_fails(self, mock_detect, async def test_discovering_configuration_auto_detect_fails(mock_detect,
mock_elevation): mock_elevation,
hass):
"""Test config remains unchanged if discovery fails.""" """Test config remains unchanged if discovery fails."""
self.hass.config = Config() hass.config = Config()
self.hass.config.config_dir = "/test/config" hass.config.config_dir = "/test/config"
run_coroutine_threadsafe( await config_util.async_process_ha_core_config(hass, {})
config_util.async_process_ha_core_config(
self.hass, {}), self.hass.loop
).result()
blankConfig = Config() blankConfig = Config()
assert self.hass.config.latitude == blankConfig.latitude assert hass.config.latitude == blankConfig.latitude
assert self.hass.config.longitude == blankConfig.longitude assert hass.config.longitude == blankConfig.longitude
assert self.hass.config.elevation == blankConfig.elevation assert hass.config.elevation == blankConfig.elevation
assert self.hass.config.location_name == blankConfig.location_name assert hass.config.location_name == blankConfig.location_name
assert self.hass.config.units == blankConfig.units assert hass.config.units == blankConfig.units
assert self.hass.config.time_zone == blankConfig.time_zone assert hass.config.time_zone == blankConfig.time_zone
assert len(self.hass.config.whitelist_external_dirs) == 1 assert len(hass.config.whitelist_external_dirs) == 1
assert "/test/config/www" in self.hass.config.whitelist_external_dirs assert "/test/config/www" in hass.config.whitelist_external_dirs
@asynctest.mock.patch( @asynctest.mock.patch(
'homeassistant.scripts.check_config.check_ha_config_file') 'homeassistant.scripts.check_config.check_ha_config_file')
def test_check_ha_config_file_correct(self, mock_check): async def test_check_ha_config_file_correct(mock_check, hass):
"""Check that restart propagates to stop.""" """Check that restart propagates to stop."""
mock_check.return_value = check_config.HomeAssistantConfig() mock_check.return_value = check_config.HomeAssistantConfig()
assert run_coroutine_threadsafe( assert await config_util.async_check_ha_config_file(hass) is None
config_util.async_check_ha_config_file(self.hass),
self.hass.loop
).result() is None
@asynctest.mock.patch( @asynctest.mock.patch(
'homeassistant.scripts.check_config.check_ha_config_file') 'homeassistant.scripts.check_config.check_ha_config_file')
def test_check_ha_config_file_wrong(self, mock_check): async def test_check_ha_config_file_wrong(mock_check, hass):
"""Check that restart with a bad config doesn't propagate to stop.""" """Check that restart with a bad config doesn't propagate to stop."""
mock_check.return_value = check_config.HomeAssistantConfig() mock_check.return_value = check_config.HomeAssistantConfig()
mock_check.return_value.add_error("bad") mock_check.return_value.add_error("bad")
assert run_coroutine_threadsafe( assert await config_util.async_check_ha_config_file(hass) == 'bad'
config_util.async_check_ha_config_file(self.hass),
self.hass.loop
).result() == 'bad'
@asynctest.mock.patch('homeassistant.config.os.path.isfile', @asynctest.mock.patch('homeassistant.config.os.path.isfile',

View file

@ -1,10 +1,9 @@
"""Test Home Assistant color util methods.""" """Test Home Assistant color util methods."""
import unittest
import homeassistant.util.color as color_util
import pytest import pytest
import voluptuous as vol import voluptuous as vol
import homeassistant.util.color as color_util
GAMUT = color_util.GamutType(color_util.XYPoint(0.704, 0.296), GAMUT = color_util.GamutType(color_util.XYPoint(0.704, 0.296),
color_util.XYPoint(0.2151, 0.7106), color_util.XYPoint(0.2151, 0.7106),
color_util.XYPoint(0.138, 0.08)) color_util.XYPoint(0.138, 0.08))
@ -22,228 +21,180 @@ GAMUT_INVALID_4 = color_util.GamutType(color_util.XYPoint(0.1, 0.1),
color_util.XYPoint(0.7, 0.7)) color_util.XYPoint(0.7, 0.7))
class TestColorUtil(unittest.TestCase):
"""Test color util methods."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
def test_color_RGB_to_xy_brightness(self): def test_color_RGB_to_xy_brightness():
"""Test color_RGB_to_xy_brightness.""" """Test color_RGB_to_xy_brightness."""
assert (0, 0, 0) == \ assert color_util.color_RGB_to_xy_brightness(0, 0, 0) == (0, 0, 0)
color_util.color_RGB_to_xy_brightness(0, 0, 0) assert color_util.color_RGB_to_xy_brightness(255, 255, 255) == \
assert (0.323, 0.329, 255) == \ (0.323, 0.329, 255)
color_util.color_RGB_to_xy_brightness(255, 255, 255)
assert (0.136, 0.04, 12) == \ assert color_util.color_RGB_to_xy_brightness(0, 0, 255) == \
color_util.color_RGB_to_xy_brightness(0, 0, 255) (0.136, 0.04, 12)
assert (0.172, 0.747, 170) == \ assert color_util.color_RGB_to_xy_brightness(0, 255, 0) == \
color_util.color_RGB_to_xy_brightness(0, 255, 0) (0.172, 0.747, 170)
assert (0.701, 0.299, 72) == \ assert color_util.color_RGB_to_xy_brightness(255, 0, 0) == \
color_util.color_RGB_to_xy_brightness(255, 0, 0) (0.701, 0.299, 72)
assert (0.701, 0.299, 16) == \ assert color_util.color_RGB_to_xy_brightness(128, 0, 0) == \
color_util.color_RGB_to_xy_brightness(128, 0, 0) (0.701, 0.299, 16)
assert (0.7, 0.299, 72) == \ assert color_util.color_RGB_to_xy_brightness(255, 0, 0, GAMUT) == \
color_util.color_RGB_to_xy_brightness(255, 0, 0, GAMUT) (0.7, 0.299, 72)
assert (0.215, 0.711, 170) == \ assert color_util.color_RGB_to_xy_brightness(0, 255, 0, GAMUT) == \
color_util.color_RGB_to_xy_brightness(0, 255, 0, GAMUT) (0.215, 0.711, 170)
assert (0.138, 0.08, 12) == \ assert color_util.color_RGB_to_xy_brightness(0, 0, 255, GAMUT) == \
color_util.color_RGB_to_xy_brightness(0, 0, 255, GAMUT) (0.138, 0.08, 12)
def test_color_RGB_to_xy(self):
def test_color_RGB_to_xy():
"""Test color_RGB_to_xy.""" """Test color_RGB_to_xy."""
assert (0, 0) == \ assert color_util.color_RGB_to_xy(0, 0, 0) == (0, 0)
color_util.color_RGB_to_xy(0, 0, 0) assert color_util.color_RGB_to_xy(255, 255, 255) == (0.323, 0.329)
assert (0.323, 0.329) == \
color_util.color_RGB_to_xy(255, 255, 255)
assert (0.136, 0.04) == \ assert color_util.color_RGB_to_xy(0, 0, 255) == (0.136, 0.04)
color_util.color_RGB_to_xy(0, 0, 255)
assert (0.172, 0.747) == \ assert color_util.color_RGB_to_xy(0, 255, 0) == (0.172, 0.747)
color_util.color_RGB_to_xy(0, 255, 0)
assert (0.701, 0.299) == \ assert color_util.color_RGB_to_xy(255, 0, 0) == (0.701, 0.299)
color_util.color_RGB_to_xy(255, 0, 0)
assert (0.701, 0.299) == \ assert color_util.color_RGB_to_xy(128, 0, 0) == (0.701, 0.299)
color_util.color_RGB_to_xy(128, 0, 0)
assert (0.138, 0.08) == \ assert color_util.color_RGB_to_xy(0, 0, 255, GAMUT) == (0.138, 0.08)
color_util.color_RGB_to_xy(0, 0, 255, GAMUT)
assert (0.215, 0.711) == \ assert color_util.color_RGB_to_xy(0, 255, 0, GAMUT) == (0.215, 0.711)
color_util.color_RGB_to_xy(0, 255, 0, GAMUT)
assert (0.7, 0.299) == \ assert color_util.color_RGB_to_xy(255, 0, 0, GAMUT) == (0.7, 0.299)
color_util.color_RGB_to_xy(255, 0, 0, GAMUT)
def test_color_xy_brightness_to_RGB(self):
def test_color_xy_brightness_to_RGB():
"""Test color_xy_brightness_to_RGB.""" """Test color_xy_brightness_to_RGB."""
assert (0, 0, 0) == \ assert color_util.color_xy_brightness_to_RGB(1, 1, 0) == (0, 0, 0)
color_util.color_xy_brightness_to_RGB(1, 1, 0)
assert (194, 186, 169) == \ assert color_util.color_xy_brightness_to_RGB(.35, .35, 128) == \
color_util.color_xy_brightness_to_RGB(.35, .35, 128) (194, 186, 169)
assert (255, 243, 222) == \ assert color_util.color_xy_brightness_to_RGB(.35, .35, 255) == \
color_util.color_xy_brightness_to_RGB(.35, .35, 255) (255, 243, 222)
assert (255, 0, 60) == \ assert color_util.color_xy_brightness_to_RGB(1, 0, 255) == (255, 0, 60)
color_util.color_xy_brightness_to_RGB(1, 0, 255)
assert (0, 255, 0) == \ assert color_util.color_xy_brightness_to_RGB(0, 1, 255) == (0, 255, 0)
color_util.color_xy_brightness_to_RGB(0, 1, 255)
assert (0, 63, 255) == \ assert color_util.color_xy_brightness_to_RGB(0, 0, 255) == (0, 63, 255)
color_util.color_xy_brightness_to_RGB(0, 0, 255)
assert (255, 0, 3) == \ assert color_util.color_xy_brightness_to_RGB(1, 0, 255, GAMUT) == \
color_util.color_xy_brightness_to_RGB(1, 0, 255, GAMUT) (255, 0, 3)
assert (82, 255, 0) == \ assert color_util.color_xy_brightness_to_RGB(0, 1, 255, GAMUT) == \
color_util.color_xy_brightness_to_RGB(0, 1, 255, GAMUT) (82, 255, 0)
assert (9, 85, 255) == \ assert color_util.color_xy_brightness_to_RGB(0, 0, 255, GAMUT) == \
color_util.color_xy_brightness_to_RGB(0, 0, 255, GAMUT) (9, 85, 255)
def test_color_xy_to_RGB(self):
def test_color_xy_to_RGB():
"""Test color_xy_to_RGB.""" """Test color_xy_to_RGB."""
assert (255, 243, 222) == \ assert color_util.color_xy_to_RGB(.35, .35) == (255, 243, 222)
color_util.color_xy_to_RGB(.35, .35)
assert (255, 0, 60) == \ assert color_util.color_xy_to_RGB(1, 0) == (255, 0, 60)
color_util.color_xy_to_RGB(1, 0)
assert (0, 255, 0) == \ assert color_util.color_xy_to_RGB(0, 1) == (0, 255, 0)
color_util.color_xy_to_RGB(0, 1)
assert (0, 63, 255) == \ assert color_util.color_xy_to_RGB(0, 0) == (0, 63, 255)
color_util.color_xy_to_RGB(0, 0)
assert (255, 0, 3) == \ assert color_util.color_xy_to_RGB(1, 0, GAMUT) == (255, 0, 3)
color_util.color_xy_to_RGB(1, 0, GAMUT)
assert (82, 255, 0) == \ assert color_util.color_xy_to_RGB(0, 1, GAMUT) == (82, 255, 0)
color_util.color_xy_to_RGB(0, 1, GAMUT)
assert (9, 85, 255) == \ assert color_util.color_xy_to_RGB(0, 0, GAMUT) == (9, 85, 255)
color_util.color_xy_to_RGB(0, 0, GAMUT)
def test_color_RGB_to_hsv(self):
def test_color_RGB_to_hsv():
"""Test color_RGB_to_hsv.""" """Test color_RGB_to_hsv."""
assert (0, 0, 0) == \ assert color_util.color_RGB_to_hsv(0, 0, 0) == (0, 0, 0)
color_util.color_RGB_to_hsv(0, 0, 0)
assert (0, 0, 100) == \ assert color_util.color_RGB_to_hsv(255, 255, 255) == (0, 0, 100)
color_util.color_RGB_to_hsv(255, 255, 255)
assert (240, 100, 100) == \ assert color_util.color_RGB_to_hsv(0, 0, 255) == (240, 100, 100)
color_util.color_RGB_to_hsv(0, 0, 255)
assert (120, 100, 100) == \ assert color_util.color_RGB_to_hsv(0, 255, 0) == (120, 100, 100)
color_util.color_RGB_to_hsv(0, 255, 0)
assert (0, 100, 100) == \ assert color_util.color_RGB_to_hsv(255, 0, 0) == (0, 100, 100)
color_util.color_RGB_to_hsv(255, 0, 0)
def test_color_hsv_to_RGB(self):
def test_color_hsv_to_RGB():
"""Test color_hsv_to_RGB.""" """Test color_hsv_to_RGB."""
assert (0, 0, 0) == \ assert color_util.color_hsv_to_RGB(0, 0, 0) == (0, 0, 0)
color_util.color_hsv_to_RGB(0, 0, 0)
assert (255, 255, 255) == \ assert color_util.color_hsv_to_RGB(0, 0, 100) == (255, 255, 255)
color_util.color_hsv_to_RGB(0, 0, 100)
assert (0, 0, 255) == \ assert color_util.color_hsv_to_RGB(240, 100, 100) == (0, 0, 255)
color_util.color_hsv_to_RGB(240, 100, 100)
assert (0, 255, 0) == \ assert color_util.color_hsv_to_RGB(120, 100, 100) == (0, 255, 0)
color_util.color_hsv_to_RGB(120, 100, 100)
assert (255, 0, 0) == \ assert color_util.color_hsv_to_RGB(0, 100, 100) == (255, 0, 0)
color_util.color_hsv_to_RGB(0, 100, 100)
def test_color_hsb_to_RGB(self):
def test_color_hsb_to_RGB():
"""Test color_hsb_to_RGB.""" """Test color_hsb_to_RGB."""
assert (0, 0, 0) == \ assert color_util.color_hsb_to_RGB(0, 0, 0) == (0, 0, 0)
color_util.color_hsb_to_RGB(0, 0, 0)
assert (255, 255, 255) == \ assert color_util.color_hsb_to_RGB(0, 0, 1.0) == (255, 255, 255)
color_util.color_hsb_to_RGB(0, 0, 1.0)
assert (0, 0, 255) == \ assert color_util.color_hsb_to_RGB(240, 1.0, 1.0) == (0, 0, 255)
color_util.color_hsb_to_RGB(240, 1.0, 1.0)
assert (0, 255, 0) == \ assert color_util.color_hsb_to_RGB(120, 1.0, 1.0) == (0, 255, 0)
color_util.color_hsb_to_RGB(120, 1.0, 1.0)
assert (255, 0, 0) == \ assert color_util.color_hsb_to_RGB(0, 1.0, 1.0) == (255, 0, 0)
color_util.color_hsb_to_RGB(0, 1.0, 1.0)
def test_color_xy_to_hs(self):
def test_color_xy_to_hs():
"""Test color_xy_to_hs.""" """Test color_xy_to_hs."""
assert (47.294, 100) == \ assert color_util.color_xy_to_hs(1, 1) == (47.294, 100)
color_util.color_xy_to_hs(1, 1)
assert (38.182, 12.941) == \ assert color_util.color_xy_to_hs(.35, .35) == (38.182, 12.941)
color_util.color_xy_to_hs(.35, .35)
assert (345.882, 100) == \ assert color_util.color_xy_to_hs(1, 0) == (345.882, 100)
color_util.color_xy_to_hs(1, 0)
assert (120, 100) == \ assert color_util.color_xy_to_hs(0, 1) == (120, 100)
color_util.color_xy_to_hs(0, 1)
assert (225.176, 100) == \ assert color_util.color_xy_to_hs(0, 0) == (225.176, 100)
color_util.color_xy_to_hs(0, 0)
assert (359.294, 100) == \ assert color_util.color_xy_to_hs(1, 0, GAMUT) == (359.294, 100)
color_util.color_xy_to_hs(1, 0, GAMUT)
assert (100.706, 100) == \ assert color_util.color_xy_to_hs(0, 1, GAMUT) == (100.706, 100)
color_util.color_xy_to_hs(0, 1, GAMUT)
assert (221.463, 96.471) == \ assert color_util.color_xy_to_hs(0, 0, GAMUT) == (221.463, 96.471)
color_util.color_xy_to_hs(0, 0, GAMUT)
def test_color_hs_to_xy(self):
def test_color_hs_to_xy():
"""Test color_hs_to_xy.""" """Test color_hs_to_xy."""
assert (0.151, 0.343) == \ assert color_util.color_hs_to_xy(180, 100) == (0.151, 0.343)
color_util.color_hs_to_xy(180, 100)
assert (0.356, 0.321) == \ assert color_util.color_hs_to_xy(350, 12.5) == (0.356, 0.321)
color_util.color_hs_to_xy(350, 12.5)
assert (0.229, 0.474) == \ assert color_util.color_hs_to_xy(140, 50) == (0.229, 0.474)
color_util.color_hs_to_xy(140, 50)
assert (0.474, 0.317) == \ assert color_util.color_hs_to_xy(0, 40) == (0.474, 0.317)
color_util.color_hs_to_xy(0, 40)
assert (0.323, 0.329) == \ assert color_util.color_hs_to_xy(360, 0) == (0.323, 0.329)
color_util.color_hs_to_xy(360, 0)
assert (0.7, 0.299) == \ assert color_util.color_hs_to_xy(0, 100, GAMUT) == (0.7, 0.299)
color_util.color_hs_to_xy(0, 100, GAMUT)
assert (0.215, 0.711) == \ assert color_util.color_hs_to_xy(120, 100, GAMUT) == (0.215, 0.711)
color_util.color_hs_to_xy(120, 100, GAMUT)
assert (0.17, 0.34) == \ assert color_util.color_hs_to_xy(180, 100, GAMUT) == (0.17, 0.34)
color_util.color_hs_to_xy(180, 100, GAMUT)
assert (0.138, 0.08) == \ assert color_util.color_hs_to_xy(240, 100, GAMUT) == (0.138, 0.08)
color_util.color_hs_to_xy(240, 100, GAMUT)
assert (0.7, 0.299) == \ assert color_util.color_hs_to_xy(360, 100, GAMUT) == (0.7, 0.299)
color_util.color_hs_to_xy(360, 100, GAMUT)
def test_rgb_hex_to_rgb_list(self):
def test_rgb_hex_to_rgb_list():
"""Test rgb_hex_to_rgb_list.""" """Test rgb_hex_to_rgb_list."""
assert [255, 255, 255] == \ assert [255, 255, 255] == \
color_util.rgb_hex_to_rgb_list('ffffff') color_util.rgb_hex_to_rgb_list('ffffff')
@ -263,94 +214,77 @@ class TestColorUtil(unittest.TestCase):
assert [51, 153, 255, 0] == \ assert [51, 153, 255, 0] == \
color_util.rgb_hex_to_rgb_list('3399ff00') color_util.rgb_hex_to_rgb_list('3399ff00')
def test_color_name_to_rgb_valid_name(self):
def test_color_name_to_rgb_valid_name():
"""Test color_name_to_rgb.""" """Test color_name_to_rgb."""
assert (255, 0, 0) == \ assert color_util.color_name_to_rgb('red') == (255, 0, 0)
color_util.color_name_to_rgb('red')
assert (0, 0, 255) == \ assert color_util.color_name_to_rgb('blue') == (0, 0, 255)
color_util.color_name_to_rgb('blue')
assert (0, 128, 0) == \ assert color_util.color_name_to_rgb('green') == (0, 128, 0)
color_util.color_name_to_rgb('green')
# spaces in the name # spaces in the name
assert (72, 61, 139) == \ assert color_util.color_name_to_rgb('dark slate blue') == (72, 61, 139)
color_util.color_name_to_rgb('dark slate blue')
# spaces removed from name # spaces removed from name
assert (72, 61, 139) == \ assert color_util.color_name_to_rgb('darkslateblue') == (72, 61, 139)
color_util.color_name_to_rgb('darkslateblue') assert color_util.color_name_to_rgb('dark slateblue') == (72, 61, 139)
assert (72, 61, 139) == \ assert color_util.color_name_to_rgb('darkslate blue') == (72, 61, 139)
color_util.color_name_to_rgb('dark slateblue')
assert (72, 61, 139) == \
color_util.color_name_to_rgb('darkslate blue')
def test_color_name_to_rgb_unknown_name_raises_value_error(self):
def test_color_name_to_rgb_unknown_name_raises_value_error():
"""Test color_name_to_rgb.""" """Test color_name_to_rgb."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
color_util.color_name_to_rgb('not a color') color_util.color_name_to_rgb('not a color')
def test_color_rgb_to_rgbw(self):
def test_color_rgb_to_rgbw():
"""Test color_rgb_to_rgbw.""" """Test color_rgb_to_rgbw."""
assert (0, 0, 0, 0) == \ assert color_util.color_rgb_to_rgbw(0, 0, 0) == (0, 0, 0, 0)
color_util.color_rgb_to_rgbw(0, 0, 0)
assert (0, 0, 0, 255) == \ assert color_util.color_rgb_to_rgbw(255, 255, 255) == (0, 0, 0, 255)
color_util.color_rgb_to_rgbw(255, 255, 255)
assert (255, 0, 0, 0) == \ assert color_util.color_rgb_to_rgbw(255, 0, 0) == (255, 0, 0, 0)
color_util.color_rgb_to_rgbw(255, 0, 0)
assert (0, 255, 0, 0) == \ assert color_util.color_rgb_to_rgbw(0, 255, 0) == (0, 255, 0, 0)
color_util.color_rgb_to_rgbw(0, 255, 0)
assert (0, 0, 255, 0) == \ assert color_util.color_rgb_to_rgbw(0, 0, 255) == (0, 0, 255, 0)
color_util.color_rgb_to_rgbw(0, 0, 255)
assert (255, 127, 0, 0) == \ assert color_util.color_rgb_to_rgbw(255, 127, 0) == (255, 127, 0, 0)
color_util.color_rgb_to_rgbw(255, 127, 0)
assert (255, 0, 0, 253) == \ assert color_util.color_rgb_to_rgbw(255, 127, 127) == (255, 0, 0, 253)
color_util.color_rgb_to_rgbw(255, 127, 127)
assert (0, 0, 0, 127) == \ assert color_util.color_rgb_to_rgbw(127, 127, 127) == (0, 0, 0, 127)
color_util.color_rgb_to_rgbw(127, 127, 127)
def test_color_rgbw_to_rgb(self):
def test_color_rgbw_to_rgb():
"""Test color_rgbw_to_rgb.""" """Test color_rgbw_to_rgb."""
assert (0, 0, 0) == \ assert color_util.color_rgbw_to_rgb(0, 0, 0, 0) == (0, 0, 0)
color_util.color_rgbw_to_rgb(0, 0, 0, 0)
assert (255, 255, 255) == \ assert color_util.color_rgbw_to_rgb(0, 0, 0, 255) == (255, 255, 255)
color_util.color_rgbw_to_rgb(0, 0, 0, 255)
assert (255, 0, 0) == \ assert color_util.color_rgbw_to_rgb(255, 0, 0, 0) == (255, 0, 0)
color_util.color_rgbw_to_rgb(255, 0, 0, 0)
assert (0, 255, 0) == \ assert color_util.color_rgbw_to_rgb(0, 255, 0, 0) == (0, 255, 0)
color_util.color_rgbw_to_rgb(0, 255, 0, 0)
assert (0, 0, 255) == \ assert color_util.color_rgbw_to_rgb(0, 0, 255, 0) == (0, 0, 255)
color_util.color_rgbw_to_rgb(0, 0, 255, 0)
assert (255, 127, 0) == \ assert color_util.color_rgbw_to_rgb(255, 127, 0, 0) == (255, 127, 0)
color_util.color_rgbw_to_rgb(255, 127, 0, 0)
assert (255, 127, 127) == \ assert color_util.color_rgbw_to_rgb(255, 0, 0, 253) == (255, 127, 127)
color_util.color_rgbw_to_rgb(255, 0, 0, 253)
assert (127, 127, 127) == \ assert color_util.color_rgbw_to_rgb(0, 0, 0, 127) == (127, 127, 127)
color_util.color_rgbw_to_rgb(0, 0, 0, 127)
def test_color_rgb_to_hex(self):
def test_color_rgb_to_hex():
"""Test color_rgb_to_hex.""" """Test color_rgb_to_hex."""
assert color_util.color_rgb_to_hex(255, 255, 255) == 'ffffff' assert color_util.color_rgb_to_hex(255, 255, 255) == 'ffffff'
assert color_util.color_rgb_to_hex(0, 0, 0) == '000000' assert color_util.color_rgb_to_hex(0, 0, 0) == '000000'
assert color_util.color_rgb_to_hex(51, 153, 255) == '3399ff' assert color_util.color_rgb_to_hex(51, 153, 255) == '3399ff'
assert color_util.color_rgb_to_hex(255, 67.9204190, 0) == 'ff4400' assert color_util.color_rgb_to_hex(255, 67.9204190, 0) == 'ff4400'
def test_gamut(self):
def test_gamut():
"""Test gamut functions.""" """Test gamut functions."""
assert color_util.check_valid_gamut(GAMUT) assert color_util.check_valid_gamut(GAMUT)
assert not color_util.check_valid_gamut(GAMUT_INVALID_1) assert not color_util.check_valid_gamut(GAMUT_INVALID_1)
@ -359,50 +293,45 @@ class TestColorUtil(unittest.TestCase):
assert not color_util.check_valid_gamut(GAMUT_INVALID_4) assert not color_util.check_valid_gamut(GAMUT_INVALID_4)
class ColorTemperatureMiredToKelvinTests(unittest.TestCase): def test_should_return_25000_kelvin_when_input_is_40_mired():
"""Test color_temperature_mired_to_kelvin."""
def test_should_return_25000_kelvin_when_input_is_40_mired(self):
"""Function should return 25000K if given 40 mired.""" """Function should return 25000K if given 40 mired."""
kelvin = color_util.color_temperature_mired_to_kelvin(40) kelvin = color_util.color_temperature_mired_to_kelvin(40)
assert 25000 == kelvin assert kelvin == 25000
def test_should_return_5000_kelvin_when_input_is_200_mired(self):
def test_should_return_5000_kelvin_when_input_is_200_mired():
"""Function should return 5000K if given 200 mired.""" """Function should return 5000K if given 200 mired."""
kelvin = color_util.color_temperature_mired_to_kelvin(200) kelvin = color_util.color_temperature_mired_to_kelvin(200)
assert 5000 == kelvin assert kelvin == 5000
class ColorTemperatureKelvinToMiredTests(unittest.TestCase): def test_should_return_40_mired_when_input_is_25000_kelvin():
"""Test color_temperature_kelvin_to_mired."""
def test_should_return_40_mired_when_input_is_25000_kelvin(self):
"""Function should return 40 mired when given 25000 Kelvin.""" """Function should return 40 mired when given 25000 Kelvin."""
mired = color_util.color_temperature_kelvin_to_mired(25000) mired = color_util.color_temperature_kelvin_to_mired(25000)
assert 40 == mired assert mired == 40
def test_should_return_200_mired_when_input_is_5000_kelvin(self):
def test_should_return_200_mired_when_input_is_5000_kelvin():
"""Function should return 200 mired when given 5000 Kelvin.""" """Function should return 200 mired when given 5000 Kelvin."""
mired = color_util.color_temperature_kelvin_to_mired(5000) mired = color_util.color_temperature_kelvin_to_mired(5000)
assert 200 == mired assert mired == 200
class ColorTemperatureToRGB(unittest.TestCase): def test_returns_same_value_for_any_two_temperatures_below_1000():
"""Test color_temperature_to_rgb."""
def test_returns_same_value_for_any_two_temperatures_below_1000(self):
"""Function should return same value for 999 Kelvin and 0 Kelvin.""" """Function should return same value for 999 Kelvin and 0 Kelvin."""
rgb_1 = color_util.color_temperature_to_rgb(999) rgb_1 = color_util.color_temperature_to_rgb(999)
rgb_2 = color_util.color_temperature_to_rgb(0) rgb_2 = color_util.color_temperature_to_rgb(0)
assert rgb_1 == rgb_2 assert rgb_1 == rgb_2
def test_returns_same_value_for_any_two_temperatures_above_40000(self):
def test_returns_same_value_for_any_two_temperatures_above_40000():
"""Function should return same value for 40001K and 999999K.""" """Function should return same value for 40001K and 999999K."""
rgb_1 = color_util.color_temperature_to_rgb(40001) rgb_1 = color_util.color_temperature_to_rgb(40001)
rgb_2 = color_util.color_temperature_to_rgb(999999) rgb_2 = color_util.color_temperature_to_rgb(999999)
assert rgb_1 == rgb_2 assert rgb_1 == rgb_2
def test_should_return_pure_white_at_6600(self):
def test_should_return_pure_white_at_6600():
""" """
Function should return red=255, blue=255, green=255 when given 6600K. Function should return red=255, blue=255, green=255 when given 6600K.
@ -413,13 +342,15 @@ class ColorTemperatureToRGB(unittest.TestCase):
rgb = color_util.color_temperature_to_rgb(6600) rgb = color_util.color_temperature_to_rgb(6600)
assert (255, 255, 255) == rgb assert (255, 255, 255) == rgb
def test_color_above_6600_should_have_more_blue_than_red_or_green(self):
def test_color_above_6600_should_have_more_blue_than_red_or_green():
"""Function should return a higher blue value for blue-ish light.""" """Function should return a higher blue value for blue-ish light."""
rgb = color_util.color_temperature_to_rgb(6700) rgb = color_util.color_temperature_to_rgb(6700)
assert rgb[2] > rgb[1] assert rgb[2] > rgb[1]
assert rgb[2] > rgb[0] assert rgb[2] > rgb[0]
def test_color_below_6600_should_have_more_red_than_blue_or_green(self):
def test_color_below_6600_should_have_more_red_than_blue_or_green():
"""Function should return a higher red value for red-ish light.""" """Function should return a higher red value for red-ish light."""
rgb = color_util.color_temperature_to_rgb(6500) rgb = color_util.color_temperature_to_rgb(6500)
assert rgb[0] > rgb[1] assert rgb[0] > rgb[1]

View file

@ -1,79 +1,68 @@
"""Test homeassistant distance utility functions.""" """Test homeassistant distance utility functions."""
import unittest import pytest
import homeassistant.util.distance as distance_util import homeassistant.util.distance as distance_util
from homeassistant.const import (LENGTH_KILOMETERS, LENGTH_METERS, LENGTH_FEET, from homeassistant.const import (LENGTH_KILOMETERS, LENGTH_METERS, LENGTH_FEET,
LENGTH_MILES) LENGTH_MILES)
import pytest
INVALID_SYMBOL = 'bob' INVALID_SYMBOL = 'bob'
VALID_SYMBOL = LENGTH_KILOMETERS VALID_SYMBOL = LENGTH_KILOMETERS
class TestDistanceUtil(unittest.TestCase): def test_convert_same_unit():
"""Test the distance utility functions."""
def test_convert_same_unit(self):
"""Test conversion from any unit to same unit.""" """Test conversion from any unit to same unit."""
assert 5 == distance_util.convert(5, LENGTH_KILOMETERS, assert distance_util.convert(5, LENGTH_KILOMETERS, LENGTH_KILOMETERS) == 5
LENGTH_KILOMETERS) assert distance_util.convert(2, LENGTH_METERS, LENGTH_METERS) == 2
assert 2 == distance_util.convert(2, LENGTH_METERS, assert distance_util.convert(10, LENGTH_MILES, LENGTH_MILES) == 10
LENGTH_METERS) assert distance_util.convert(9, LENGTH_FEET, LENGTH_FEET) == 9
assert 10 == distance_util.convert(10, LENGTH_MILES, LENGTH_MILES)
assert 9 == distance_util.convert(9, LENGTH_FEET, LENGTH_FEET)
def test_convert_invalid_unit(self):
def test_convert_invalid_unit():
"""Test exception is thrown for invalid units.""" """Test exception is thrown for invalid units."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
distance_util.convert(5, INVALID_SYMBOL, distance_util.convert(5, INVALID_SYMBOL, VALID_SYMBOL)
VALID_SYMBOL)
with pytest.raises(ValueError): with pytest.raises(ValueError):
distance_util.convert(5, VALID_SYMBOL, distance_util.convert(5, VALID_SYMBOL, INVALID_SYMBOL)
INVALID_SYMBOL)
def test_convert_nonnumeric_value(self):
def test_convert_nonnumeric_value():
"""Test exception is thrown for nonnumeric type.""" """Test exception is thrown for nonnumeric type."""
with pytest.raises(TypeError): with pytest.raises(TypeError):
distance_util.convert('a', LENGTH_KILOMETERS, LENGTH_METERS) distance_util.convert('a', LENGTH_KILOMETERS, LENGTH_METERS)
def test_convert_from_miles(self):
def test_convert_from_miles():
"""Test conversion from miles to other units.""" """Test conversion from miles to other units."""
miles = 5 miles = 5
assert distance_util.convert( assert distance_util.convert(miles, LENGTH_MILES, LENGTH_KILOMETERS) == \
miles, LENGTH_MILES, LENGTH_KILOMETERS 8.04672
) == 8.04672 assert distance_util.convert(miles, LENGTH_MILES, LENGTH_METERS) == 8046.72
assert distance_util.convert(miles, LENGTH_MILES, LENGTH_METERS) == \
8046.72
assert distance_util.convert(miles, LENGTH_MILES, LENGTH_FEET) == \ assert distance_util.convert(miles, LENGTH_MILES, LENGTH_FEET) == \
26400.0008448 26400.0008448
def test_convert_from_feet(self):
def test_convert_from_feet():
"""Test conversion from feet to other units.""" """Test conversion from feet to other units."""
feet = 5000 feet = 5000
assert distance_util.convert(feet, LENGTH_FEET, LENGTH_KILOMETERS) == \ assert distance_util.convert(feet, LENGTH_FEET, LENGTH_KILOMETERS) == 1.524
1.524 assert distance_util.convert(feet, LENGTH_FEET, LENGTH_METERS) == 1524
assert distance_util.convert(feet, LENGTH_FEET, LENGTH_METERS) == \
1524
assert distance_util.convert(feet, LENGTH_FEET, LENGTH_MILES) == \ assert distance_util.convert(feet, LENGTH_FEET, LENGTH_MILES) == \
0.9469694040000001 0.9469694040000001
def test_convert_from_kilometers(self):
def test_convert_from_kilometers():
"""Test conversion from kilometers to other units.""" """Test conversion from kilometers to other units."""
km = 5 km = 5
assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_FEET) == \ assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_FEET) == 16404.2
16404.2 assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_METERS) == 5000
assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_METERS) == \
5000
assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_MILES) == \ assert distance_util.convert(km, LENGTH_KILOMETERS, LENGTH_MILES) == \
3.106855 3.106855
def test_convert_from_meters(self):
def test_convert_from_meters():
"""Test conversion from meters to other units.""" """Test conversion from meters to other units."""
m = 5000 m = 5000
assert distance_util.convert(m, LENGTH_METERS, LENGTH_FEET) == \ assert distance_util.convert(m, LENGTH_METERS, LENGTH_FEET) == 16404.2
16404.2 assert distance_util.convert(m, LENGTH_METERS, LENGTH_KILOMETERS) == 5
assert distance_util.convert(m, LENGTH_METERS, LENGTH_KILOMETERS) == \ assert distance_util.convert(m, LENGTH_METERS, LENGTH_MILES) == 3.106855
5
assert distance_util.convert(m, LENGTH_METERS, LENGTH_MILES) == \
3.106855

View file

@ -1,38 +1,35 @@
"""Test Home Assistant date util methods.""" """Test Home Assistant date util methods."""
import unittest
from datetime import datetime, timedelta from datetime import datetime, timedelta
import homeassistant.util.dt as dt_util
import pytest import pytest
import homeassistant.util.dt as dt_util
DEFAULT_TIME_ZONE = dt_util.DEFAULT_TIME_ZONE
TEST_TIME_ZONE = 'America/Los_Angeles' TEST_TIME_ZONE = 'America/Los_Angeles'
class TestDateUtil(unittest.TestCase): def teardown():
"""Test util date methods."""
def setUp(self):
"""Set up the tests."""
self.orig_default_time_zone = dt_util.DEFAULT_TIME_ZONE
def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""
dt_util.set_default_time_zone(self.orig_default_time_zone) dt_util.set_default_time_zone(DEFAULT_TIME_ZONE)
def test_get_time_zone_retrieves_valid_time_zone(self):
def test_get_time_zone_retrieves_valid_time_zone():
"""Test getting a time zone.""" """Test getting a time zone."""
time_zone = dt_util.get_time_zone(TEST_TIME_ZONE) time_zone = dt_util.get_time_zone(TEST_TIME_ZONE)
assert time_zone is not None assert time_zone is not None
assert TEST_TIME_ZONE == time_zone.zone assert TEST_TIME_ZONE == time_zone.zone
def test_get_time_zone_returns_none_for_garbage_time_zone(self):
def test_get_time_zone_returns_none_for_garbage_time_zone():
"""Test getting a non existing time zone.""" """Test getting a non existing time zone."""
time_zone = dt_util.get_time_zone("Non existing time zone") time_zone = dt_util.get_time_zone("Non existing time zone")
assert time_zone is None assert time_zone is None
def test_set_default_time_zone(self):
def test_set_default_time_zone():
"""Test setting default time zone.""" """Test setting default time zone."""
time_zone = dt_util.get_time_zone(TEST_TIME_ZONE) time_zone = dt_util.get_time_zone(TEST_TIME_ZONE)
@ -41,12 +38,14 @@ class TestDateUtil(unittest.TestCase):
# We cannot compare the timezones directly because of DST # We cannot compare the timezones directly because of DST
assert time_zone.zone == dt_util.now().tzinfo.zone assert time_zone.zone == dt_util.now().tzinfo.zone
def test_utcnow(self):
def test_utcnow():
"""Test the UTC now method.""" """Test the UTC now method."""
assert abs(dt_util.utcnow().replace(tzinfo=None)-datetime.utcnow()) < \ assert abs(dt_util.utcnow().replace(tzinfo=None)-datetime.utcnow()) < \
timedelta(seconds=1) timedelta(seconds=1)
def test_now(self):
def test_now():
"""Test the now method.""" """Test the now method."""
dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE)) dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE))
@ -56,19 +55,22 @@ class TestDateUtil(unittest.TestCase):
) - datetime.utcnow() ) - datetime.utcnow()
) < timedelta(seconds=1) ) < timedelta(seconds=1)
def test_as_utc_with_naive_object(self):
def test_as_utc_with_naive_object():
"""Test the now method.""" """Test the now method."""
utcnow = datetime.utcnow() utcnow = datetime.utcnow()
assert utcnow == dt_util.as_utc(utcnow).replace(tzinfo=None) assert utcnow == dt_util.as_utc(utcnow).replace(tzinfo=None)
def test_as_utc_with_utc_object(self):
def test_as_utc_with_utc_object():
"""Test UTC time with UTC object.""" """Test UTC time with UTC object."""
utcnow = dt_util.utcnow() utcnow = dt_util.utcnow()
assert utcnow == dt_util.as_utc(utcnow) assert utcnow == dt_util.as_utc(utcnow)
def test_as_utc_with_local_object(self):
def test_as_utc_with_local_object():
"""Test the UTC time with local object.""" """Test the UTC time with local object."""
dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE)) dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE))
localnow = dt_util.now() localnow = dt_util.now()
@ -77,18 +79,21 @@ class TestDateUtil(unittest.TestCase):
assert localnow == utcnow assert localnow == utcnow
assert localnow.tzinfo != utcnow.tzinfo assert localnow.tzinfo != utcnow.tzinfo
def test_as_local_with_naive_object(self):
def test_as_local_with_naive_object():
"""Test local time with native object.""" """Test local time with native object."""
now = dt_util.now() now = dt_util.now()
assert abs(now-dt_util.as_local(datetime.utcnow())) < \ assert abs(now-dt_util.as_local(datetime.utcnow())) < \
timedelta(seconds=1) timedelta(seconds=1)
def test_as_local_with_local_object(self):
def test_as_local_with_local_object():
"""Test local with local object.""" """Test local with local object."""
now = dt_util.now() now = dt_util.now()
assert now == now assert now == now
def test_as_local_with_utc_object(self):
def test_as_local_with_utc_object():
"""Test local time with UTC object.""" """Test local time with UTC object."""
dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE)) dt_util.set_default_time_zone(dt_util.get_time_zone(TEST_TIME_ZONE))
@ -98,12 +103,14 @@ class TestDateUtil(unittest.TestCase):
assert localnow == utcnow assert localnow == utcnow
assert localnow.tzinfo != utcnow.tzinfo assert localnow.tzinfo != utcnow.tzinfo
def test_utc_from_timestamp(self):
def test_utc_from_timestamp():
"""Test utc_from_timestamp method.""" """Test utc_from_timestamp method."""
assert datetime(1986, 7, 9, tzinfo=dt_util.UTC) == \ assert datetime(1986, 7, 9, tzinfo=dt_util.UTC) == \
dt_util.utc_from_timestamp(521251200) dt_util.utc_from_timestamp(521251200)
def test_as_timestamp(self):
def test_as_timestamp():
"""Test as_timestamp method.""" """Test as_timestamp method."""
ts = 1462401234 ts = 1462401234
utc_dt = dt_util.utc_from_timestamp(ts) utc_dt = dt_util.utc_from_timestamp(ts)
@ -114,9 +121,10 @@ class TestDateUtil(unittest.TestCase):
# confirm the ability to handle a string passed in # confirm the ability to handle a string passed in
delta = dt_util.as_timestamp("2016-01-01 12:12:12") delta = dt_util.as_timestamp("2016-01-01 12:12:12")
delta -= dt_util.as_timestamp("2016-01-01 12:12:11") delta -= dt_util.as_timestamp("2016-01-01 12:12:11")
assert 1 == delta assert delta == 1
def test_parse_datetime_converts_correctly(self):
def test_parse_datetime_converts_correctly():
"""Test parse_datetime converts strings.""" """Test parse_datetime converts strings."""
assert \ assert \
datetime(1986, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC) == \ datetime(1986, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC) == \
@ -126,11 +134,13 @@ class TestDateUtil(unittest.TestCase):
assert utcnow == dt_util.parse_datetime(utcnow.isoformat()) assert utcnow == dt_util.parse_datetime(utcnow.isoformat())
def test_parse_datetime_returns_none_for_incorrect_format(self):
def test_parse_datetime_returns_none_for_incorrect_format():
"""Test parse_datetime returns None if incorrect format.""" """Test parse_datetime returns None if incorrect format."""
assert dt_util.parse_datetime("not a datetime string") is None assert dt_util.parse_datetime("not a datetime string") is None
def test_get_age(self):
def test_get_age():
"""Test get_age.""" """Test get_age."""
diff = dt_util.now() - timedelta(seconds=0) diff = dt_util.now() - timedelta(seconds=0)
assert dt_util.get_age(diff) == "0 seconds" assert dt_util.get_age(diff) == "0 seconds"
@ -162,7 +172,8 @@ class TestDateUtil(unittest.TestCase):
diff = dt_util.now() - timedelta(minutes=365*60*24) diff = dt_util.now() - timedelta(minutes=365*60*24)
assert dt_util.get_age(diff) == "1 year" assert dt_util.get_age(diff) == "1 year"
def test_parse_time_expression(self):
def test_parse_time_expression():
"""Test parse_time_expression.""" """Test parse_time_expression."""
assert [x for x in range(60)] == \ assert [x for x in range(60)] == \
dt_util.parse_time_expression('*', 0, 59) dt_util.parse_time_expression('*', 0, 59)
@ -184,7 +195,8 @@ class TestDateUtil(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
dt_util.parse_time_expression(61, 0, 60) dt_util.parse_time_expression(61, 0, 60)
def test_find_next_time_expression_time_basic(self):
def test_find_next_time_expression_time_basic():
"""Test basic stuff for find_next_time_expression_time.""" """Test basic stuff for find_next_time_expression_time."""
def find(dt, hour, minute, second): def find(dt, hour, minute, second):
"""Call test_find_next_time_expression_time.""" """Call test_find_next_time_expression_time."""
@ -207,7 +219,8 @@ class TestDateUtil(unittest.TestCase):
assert datetime(2018, 10, 8, 5, 0, 0) == \ assert datetime(2018, 10, 8, 5, 0, 0) == \
find(datetime(2018, 10, 7, 10, 30, 0), 5, 0, 0) find(datetime(2018, 10, 7, 10, 30, 0), 5, 0, 0)
def test_find_next_time_expression_time_dst(self):
def test_find_next_time_expression_time_dst():
"""Test daylight saving time for find_next_time_expression_time.""" """Test daylight saving time for find_next_time_expression_time."""
tz = dt_util.get_time_zone('Europe/Vienna') tz = dt_util.get_time_zone('Europe/Vienna')
dt_util.set_default_time_zone(tz) dt_util.set_default_time_zone(tz)

View file

@ -1,72 +1,72 @@
"""Test Home Assistant util methods.""" """Test Home Assistant util methods."""
import unittest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta from datetime import datetime, timedelta
from homeassistant import util
import homeassistant.util.dt as dt_util
import pytest import pytest
from homeassistant import util
import homeassistant.util.dt as dt_util
class TestUtil(unittest.TestCase):
"""Test util methods."""
def test_sanitize_filename(self): def test_sanitize_filename():
"""Test sanitize_filename.""" """Test sanitize_filename."""
assert "test" == util.sanitize_filename("test") assert util.sanitize_filename("test") == 'test'
assert "test" == util.sanitize_filename("/test") assert util.sanitize_filename("/test") == 'test'
assert "test" == util.sanitize_filename("..test") assert util.sanitize_filename("..test") == 'test'
assert "test" == util.sanitize_filename("\\test") assert util.sanitize_filename("\\test") == 'test'
assert "test" == util.sanitize_filename("\\../test") assert util.sanitize_filename("\\../test") == 'test'
def test_sanitize_path(self):
def test_sanitize_path():
"""Test sanitize_path.""" """Test sanitize_path."""
assert "test/path" == util.sanitize_path("test/path") assert util.sanitize_path("test/path") == 'test/path'
assert "test/path" == util.sanitize_path("~test/path") assert util.sanitize_path("~test/path") == 'test/path'
assert "//test/path" == util.sanitize_path("~/../test/path") assert util.sanitize_path("~/../test/path") == '//test/path'
def test_slugify(self):
def test_slugify():
"""Test slugify.""" """Test slugify."""
assert "t_est" == util.slugify("T-!@#$!#@$!$est") assert util.slugify("T-!@#$!#@$!$est") == 't_est'
assert "test_more" == util.slugify("Test More") assert util.slugify("Test More") == 'test_more'
assert "test_more" == util.slugify("Test_(More)") assert util.slugify("Test_(More)") == 'test_more'
assert "test_more" == util.slugify("Tèst_Mörê") assert util.slugify("Tèst_Mörê") == 'test_more'
assert "b8_27_eb_00_00_00" == util.slugify("B8:27:EB:00:00:00") assert util.slugify("B8:27:EB:00:00:00") == 'b8_27_eb_00_00_00'
assert "test_com" == util.slugify("test.com") assert util.slugify("test.com") == 'test_com'
assert "greg_phone_exp_wayp1" == \ assert util.slugify("greg_phone - exp_wayp1") == 'greg_phone_exp_wayp1'
util.slugify("greg_phone - exp_wayp1") assert util.slugify("We are, we are, a... Test Calendar") == \
assert "we_are_we_are_a_test_calendar" == \ 'we_are_we_are_a_test_calendar'
util.slugify("We are, we are, a... Test Calendar") assert util.slugify("Tèst_äöüß_ÄÖÜ") == 'test_aouss_aou'
assert "test_aouss_aou" == util.slugify("Tèst_äöüß_ÄÖÜ") assert util.slugify("影師嗎") == 'ying_shi_ma'
assert "ying_shi_ma" == util.slugify("影師嗎") assert util.slugify("けいふぉんと") == 'keihuonto'
assert "keihuonto" == util.slugify("けいふぉんと")
def test_repr_helper(self):
def test_repr_helper():
"""Test repr_helper.""" """Test repr_helper."""
assert "A" == util.repr_helper("A") assert util.repr_helper("A") == 'A'
assert "5" == util.repr_helper(5) assert util.repr_helper(5) == '5'
assert "True" == util.repr_helper(True) assert util.repr_helper(True) == 'True'
assert "test=1" == util.repr_helper({"test": 1}) assert util.repr_helper({"test": 1}) == 'test=1'
assert "1986-07-09T12:00:00+00:00" == \ assert util.repr_helper(datetime(1986, 7, 9, 12, 0, 0)) == \
util.repr_helper(datetime(1986, 7, 9, 12, 0, 0)) '1986-07-09T12:00:00+00:00'
def test_convert(self):
def test_convert():
"""Test convert.""" """Test convert."""
assert 5 == util.convert("5", int) assert util.convert("5", int) == 5
assert 5.0 == util.convert("5", float) assert util.convert("5", float) == 5.0
assert util.convert("True", bool) is True assert util.convert("True", bool) is True
assert 1 == util.convert("NOT A NUMBER", int, 1) assert util.convert("NOT A NUMBER", int, 1) == 1
assert 1 == util.convert(None, int, 1) assert util.convert(None, int, 1) == 1
assert 1 == util.convert(object, int, 1) assert util.convert(object, int, 1) == 1
def test_ensure_unique_string(self):
def test_ensure_unique_string():
"""Test ensure_unique_string.""" """Test ensure_unique_string."""
assert "Beer_3" == \ assert util.ensure_unique_string("Beer", ["Beer", "Beer_2"]) == 'Beer_3'
util.ensure_unique_string("Beer", ["Beer", "Beer_2"]) assert util.ensure_unique_string("Beer", ["Wine", "Soda"]) == 'Beer'
assert "Beer" == \
util.ensure_unique_string("Beer", ["Wine", "Soda"])
def test_ordered_enum(self):
def test_ordered_enum():
"""Test the ordered enum class.""" """Test the ordered enum class."""
class TestEnum(util.OrderedEnum): class TestEnum(util.OrderedEnum):
"""Test enum that can be ordered.""" """Test enum that can be ordered."""
@ -77,18 +77,18 @@ class TestUtil(unittest.TestCase):
assert TestEnum.SECOND >= TestEnum.FIRST assert TestEnum.SECOND >= TestEnum.FIRST
assert TestEnum.SECOND >= TestEnum.SECOND assert TestEnum.SECOND >= TestEnum.SECOND
assert not (TestEnum.SECOND >= TestEnum.THIRD) assert TestEnum.SECOND < TestEnum.THIRD
assert TestEnum.SECOND > TestEnum.FIRST assert TestEnum.SECOND > TestEnum.FIRST
assert not (TestEnum.SECOND > TestEnum.SECOND)
assert not (TestEnum.SECOND > TestEnum.THIRD)
assert not (TestEnum.SECOND <= TestEnum.FIRST)
assert TestEnum.SECOND <= TestEnum.SECOND assert TestEnum.SECOND <= TestEnum.SECOND
assert TestEnum.SECOND <= TestEnum.THIRD assert TestEnum.SECOND <= TestEnum.THIRD
assert not (TestEnum.SECOND < TestEnum.FIRST) assert TestEnum.SECOND > TestEnum.FIRST
assert not (TestEnum.SECOND < TestEnum.SECOND) assert TestEnum.SECOND <= TestEnum.SECOND
assert TestEnum.SECOND <= TestEnum.THIRD
assert TestEnum.SECOND >= TestEnum.FIRST
assert TestEnum.SECOND >= TestEnum.SECOND
assert TestEnum.SECOND < TestEnum.THIRD assert TestEnum.SECOND < TestEnum.THIRD
# Python will raise a TypeError if the <, <=, >, >= methods # Python will raise a TypeError if the <, <=, >, >= methods
@ -105,7 +105,8 @@ class TestUtil(unittest.TestCase):
with pytest.raises(TypeError): with pytest.raises(TypeError):
TestEnum.FIRST >= 1 TestEnum.FIRST >= 1
def test_throttle(self):
def test_throttle():
"""Test the add cooldown decorator.""" """Test the add cooldown decorator."""
calls1 = [] calls1 = []
calls2 = [] calls2 = []
@ -126,38 +127,39 @@ class TestUtil(unittest.TestCase):
test_throttle1() test_throttle1()
test_throttle2() test_throttle2()
assert 1 == len(calls1) assert len(calls1) == 1
assert 1 == len(calls2) assert len(calls2) == 1
# Call second time. Methods should not get called # Call second time. Methods should not get called
test_throttle1() test_throttle1()
test_throttle2() test_throttle2()
assert 1 == len(calls1) assert len(calls1) == 1
assert 1 == len(calls2) assert len(calls2) == 1
# Call again, overriding throttle, only first one should fire # Call again, overriding throttle, only first one should fire
test_throttle1(no_throttle=True) test_throttle1(no_throttle=True)
test_throttle2(no_throttle=True) test_throttle2(no_throttle=True)
assert 2 == len(calls1) assert len(calls1) == 2
assert 1 == len(calls2) assert len(calls2) == 1
with patch('homeassistant.util.utcnow', return_value=plus3): with patch('homeassistant.util.utcnow', return_value=plus3):
test_throttle1() test_throttle1()
test_throttle2() test_throttle2()
assert 2 == len(calls1) assert len(calls1) == 2
assert 1 == len(calls2) assert len(calls2) == 1
with patch('homeassistant.util.utcnow', return_value=plus5): with patch('homeassistant.util.utcnow', return_value=plus5):
test_throttle1() test_throttle1()
test_throttle2() test_throttle2()
assert 3 == len(calls1) assert len(calls1) == 3
assert 2 == len(calls2) assert len(calls2) == 2
def test_throttle_per_instance(self):
def test_throttle_per_instance():
"""Test that the throttle method is done per instance of a class.""" """Test that the throttle method is done per instance of a class."""
class Tester: class Tester:
"""A tester class for the throttle.""" """A tester class for the throttle."""
@ -170,7 +172,8 @@ class TestUtil(unittest.TestCase):
assert Tester().hello() assert Tester().hello()
assert Tester().hello() assert Tester().hello()
def test_throttle_on_method(self):
def test_throttle_on_method():
"""Test that throttle works when wrapping a method.""" """Test that throttle works when wrapping a method."""
class Tester: class Tester:
"""A tester class for the throttle.""" """A tester class for the throttle."""
@ -185,7 +188,8 @@ class TestUtil(unittest.TestCase):
assert throttled() assert throttled()
assert throttled() is None assert throttled() is None
def test_throttle_on_two_method(self):
def test_throttle_on_two_method():
"""Test that throttle works when wrapping two methods.""" """Test that throttle works when wrapping two methods."""
class Tester: class Tester:
"""A test class for the throttle.""" """A test class for the throttle."""
@ -205,8 +209,9 @@ class TestUtil(unittest.TestCase):
assert tester.hello() assert tester.hello()
assert tester.goodbye() assert tester.goodbye()
@patch.object(util, 'random') @patch.object(util, 'random')
def test_get_random_string(self, mock_random): def test_get_random_string(mock_random):
"""Test get random string.""" """Test get random string."""
results = ['A', 'B', 'C'] results = ['A', 'B', 'C']

View file

@ -2,15 +2,16 @@
from json import JSONEncoder from json import JSONEncoder
import os import os
import unittest import unittest
from unittest.mock import Mock
import sys import sys
from tempfile import mkdtemp from tempfile import mkdtemp
import pytest
from homeassistant.util.json import ( from homeassistant.util.json import (
SerializationError, load_json, save_json) SerializationError, load_json, save_json)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import pytest
from unittest.mock import Mock
# Test data that can be saved as JSON # Test data that can be saved as JSON
TEST_JSON_A = {"a": 1, "B": "two"} TEST_JSON_A = {"a": 1, "B": "two"}
@ -19,66 +20,73 @@ TEST_JSON_B = {"a": "one", "B": 2}
TEST_BAD_OBJECT = {("A",): 1} TEST_BAD_OBJECT = {("A",): 1}
# Test data that can not be loaded as JSON # Test data that can not be loaded as JSON
TEST_BAD_SERIALIED = "THIS IS NOT JSON\n" TEST_BAD_SERIALIED = "THIS IS NOT JSON\n"
TMP_DIR = None
class TestJSON(unittest.TestCase): def setup():
"""Test util.json save and load."""
def setUp(self):
"""Set up for tests.""" """Set up for tests."""
self.tmp_dir = mkdtemp() global TMP_DIR
TMP_DIR = mkdtemp()
def tearDown(self):
def teardown():
"""Clean up after tests.""" """Clean up after tests."""
for fname in os.listdir(self.tmp_dir): for fname in os.listdir(TMP_DIR):
os.remove(os.path.join(self.tmp_dir, fname)) os.remove(os.path.join(TMP_DIR, fname))
os.rmdir(self.tmp_dir) os.rmdir(TMP_DIR)
def _path_for(self, leaf_name):
return os.path.join(self.tmp_dir, leaf_name+".json")
def test_save_and_load(self): def _path_for(leaf_name):
return os.path.join(TMP_DIR, leaf_name+".json")
def test_save_and_load():
"""Test saving and loading back.""" """Test saving and loading back."""
fname = self._path_for("test1") fname = _path_for("test1")
save_json(fname, TEST_JSON_A) save_json(fname, TEST_JSON_A)
data = load_json(fname) data = load_json(fname)
assert data == TEST_JSON_A assert data == TEST_JSON_A
# Skipped on Windows # Skipped on Windows
@unittest.skipIf(sys.platform.startswith('win'), @unittest.skipIf(sys.platform.startswith('win'),
"private permissions not supported on Windows") "private permissions not supported on Windows")
def test_save_and_load_private(self): def test_save_and_load_private():
"""Test we can load private files and that they are protected.""" """Test we can load private files and that they are protected."""
fname = self._path_for("test2") fname = _path_for("test2")
save_json(fname, TEST_JSON_A, private=True) save_json(fname, TEST_JSON_A, private=True)
data = load_json(fname) data = load_json(fname)
assert data == TEST_JSON_A assert data == TEST_JSON_A
stats = os.stat(fname) stats = os.stat(fname)
assert stats.st_mode & 0o77 == 0 assert stats.st_mode & 0o77 == 0
def test_overwrite_and_reload(self):
def test_overwrite_and_reload():
"""Test that we can overwrite an existing file and read back.""" """Test that we can overwrite an existing file and read back."""
fname = self._path_for("test3") fname = _path_for("test3")
save_json(fname, TEST_JSON_A) save_json(fname, TEST_JSON_A)
save_json(fname, TEST_JSON_B) save_json(fname, TEST_JSON_B)
data = load_json(fname) data = load_json(fname)
assert data == TEST_JSON_B assert data == TEST_JSON_B
def test_save_bad_data(self):
def test_save_bad_data():
"""Test error from trying to save unserialisable data.""" """Test error from trying to save unserialisable data."""
fname = self._path_for("test4") fname = _path_for("test4")
with pytest.raises(SerializationError): with pytest.raises(SerializationError):
save_json(fname, TEST_BAD_OBJECT) save_json(fname, TEST_BAD_OBJECT)
def test_load_bad_data(self):
def test_load_bad_data():
"""Test error from trying to load unserialisable data.""" """Test error from trying to load unserialisable data."""
fname = self._path_for("test5") fname = _path_for("test5")
with open(fname, "w") as fh: with open(fname, "w") as fh:
fh.write(TEST_BAD_SERIALIED) fh.write(TEST_BAD_SERIALIED)
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
load_json(fname) load_json(fname)
def test_custom_encoder(self):
def test_custom_encoder():
"""Test serializing with a custom encoder.""" """Test serializing with a custom encoder."""
class MockJSONEncoder(JSONEncoder): class MockJSONEncoder(JSONEncoder):
"""Mock JSON encoder.""" """Mock JSON encoder."""
@ -87,7 +95,7 @@ class TestJSON(unittest.TestCase):
"""Mock JSON encode method.""" """Mock JSON encode method."""
return "9" return "9"
fname = self._path_for("test6") fname = _path_for("test6")
save_json(fname, Mock(), encoder=MockJSONEncoder) save_json(fname, Mock(), encoder=MockJSONEncoder)
data = load_json(fname) data = load_json(fname)
self.assertEqual(data, "9") assert data == "9"

View file

@ -1,5 +1,4 @@
"""Test homeassistant pressure utility functions.""" """Test homeassistant pressure utility functions."""
import unittest
import pytest import pytest
from homeassistant.const import (PRESSURE_PA, PRESSURE_HPA, PRESSURE_MBAR, from homeassistant.const import (PRESSURE_PA, PRESSURE_HPA, PRESSURE_MBAR,
@ -10,17 +9,15 @@ INVALID_SYMBOL = 'bob'
VALID_SYMBOL = PRESSURE_PA VALID_SYMBOL = PRESSURE_PA
class TestPressureUtil(unittest.TestCase): def test_convert_same_unit():
"""Test the pressure utility functions."""
def test_convert_same_unit(self):
"""Test conversion from any unit to same unit.""" """Test conversion from any unit to same unit."""
assert pressure_util.convert(2, PRESSURE_PA, PRESSURE_PA) == 2 assert pressure_util.convert(2, PRESSURE_PA, PRESSURE_PA) == 2
assert pressure_util.convert(3, PRESSURE_HPA, PRESSURE_HPA) == 3 assert pressure_util.convert(3, PRESSURE_HPA, PRESSURE_HPA) == 3
assert pressure_util.convert(4, PRESSURE_MBAR, PRESSURE_MBAR) == 4 assert pressure_util.convert(4, PRESSURE_MBAR, PRESSURE_MBAR) == 4
assert pressure_util.convert(5, PRESSURE_INHG, PRESSURE_INHG) == 5 assert pressure_util.convert(5, PRESSURE_INHG, PRESSURE_INHG) == 5
def test_convert_invalid_unit(self):
def test_convert_invalid_unit():
"""Test exception is thrown for invalid units.""" """Test exception is thrown for invalid units."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
pressure_util.convert(5, INVALID_SYMBOL, VALID_SYMBOL) pressure_util.convert(5, INVALID_SYMBOL, VALID_SYMBOL)
@ -28,39 +25,34 @@ class TestPressureUtil(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pressure_util.convert(5, VALID_SYMBOL, INVALID_SYMBOL) pressure_util.convert(5, VALID_SYMBOL, INVALID_SYMBOL)
def test_convert_nonnumeric_value(self):
def test_convert_nonnumeric_value():
"""Test exception is thrown for nonnumeric type.""" """Test exception is thrown for nonnumeric type."""
with pytest.raises(TypeError): with pytest.raises(TypeError):
pressure_util.convert('a', PRESSURE_HPA, PRESSURE_INHG) pressure_util.convert('a', PRESSURE_HPA, PRESSURE_INHG)
def test_convert_from_hpascals(self):
def test_convert_from_hpascals():
"""Test conversion from hPA to other units.""" """Test conversion from hPA to other units."""
hpascals = 1000 hpascals = 1000
self.assertAlmostEqual( assert pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_PSI) == \
pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_PSI), pytest.approx(14.5037743897)
14.5037743897) assert pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_INHG) == \
self.assertAlmostEqual( pytest.approx(29.5299801647)
pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_INHG), assert pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_PA) == \
29.5299801647) pytest.approx(100000)
self.assertAlmostEqual( assert pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_MBAR) == \
pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_PA), pytest.approx(1000)
100000)
self.assertAlmostEqual(
pressure_util.convert(hpascals, PRESSURE_HPA, PRESSURE_MBAR),
1000)
def test_convert_from_inhg(self):
def test_convert_from_inhg():
"""Test conversion from inHg to other units.""" """Test conversion from inHg to other units."""
inhg = 30 inhg = 30
self.assertAlmostEqual( assert pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_PSI) == \
pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_PSI), pytest.approx(14.7346266155)
14.7346266155) assert pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_HPA) == \
self.assertAlmostEqual( pytest.approx(1015.9167)
pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_HPA), assert pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_PA) == \
1015.9167) pytest.approx(101591.67)
self.assertAlmostEqual( assert pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_MBAR) == \
pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_PA), pytest.approx(1015.9167)
101591.67)
self.assertAlmostEqual(
pressure_util.convert(inhg, PRESSURE_INHG, PRESSURE_MBAR),
1015.9167)

View file

@ -1,6 +1,5 @@
"""Test Home Assistant ruamel.yaml loader.""" """Test Home Assistant ruamel.yaml loader."""
import os import os
import unittest
from tempfile import mkdtemp from tempfile import mkdtemp
import pytest import pytest
@ -114,44 +113,50 @@ views:
cards: !include cards.yaml cards: !include cards.yaml
""" """
TMP_DIR = None
class TestYAML(unittest.TestCase):
"""Test lovelace.yaml save and load."""
def setUp(self): def setup():
"""Set up for tests.""" """Set up for tests."""
self.tmp_dir = mkdtemp() global TMP_DIR
self.yaml = YAML(typ='rt') TMP_DIR = mkdtemp()
def tearDown(self):
def teardown():
"""Clean up after tests.""" """Clean up after tests."""
for fname in os.listdir(self.tmp_dir): for fname in os.listdir(TMP_DIR):
os.remove(os.path.join(self.tmp_dir, fname)) os.remove(os.path.join(TMP_DIR, fname))
os.rmdir(self.tmp_dir) os.rmdir(TMP_DIR)
def _path_for(self, leaf_name):
return os.path.join(self.tmp_dir, leaf_name+".yaml")
def test_save_and_load(self): def _path_for(leaf_name):
return os.path.join(TMP_DIR, leaf_name+".yaml")
def test_save_and_load():
"""Test saving and loading back.""" """Test saving and loading back."""
fname = self._path_for("test1") yaml = YAML(typ='rt')
fname = _path_for("test1")
open(fname, "w+").close() open(fname, "w+").close()
util_yaml.save_yaml(fname, self.yaml.load(TEST_YAML_A)) util_yaml.save_yaml(fname, yaml.load(TEST_YAML_A))
data = util_yaml.load_yaml(fname, True) data = util_yaml.load_yaml(fname, True)
assert data == self.yaml.load(TEST_YAML_A) assert data == yaml.load(TEST_YAML_A)
def test_overwrite_and_reload(self):
def test_overwrite_and_reload():
"""Test that we can overwrite an existing file and read back.""" """Test that we can overwrite an existing file and read back."""
fname = self._path_for("test2") yaml = YAML(typ='rt')
fname = _path_for("test2")
open(fname, "w+").close() open(fname, "w+").close()
util_yaml.save_yaml(fname, self.yaml.load(TEST_YAML_A)) util_yaml.save_yaml(fname, yaml.load(TEST_YAML_A))
util_yaml.save_yaml(fname, self.yaml.load(TEST_YAML_B)) util_yaml.save_yaml(fname, yaml.load(TEST_YAML_B))
data = util_yaml.load_yaml(fname, True) data = util_yaml.load_yaml(fname, True)
assert data == self.yaml.load(TEST_YAML_B) assert data == yaml.load(TEST_YAML_B)
def test_load_bad_data(self):
def test_load_bad_data():
"""Test error from trying to load unserialisable data.""" """Test error from trying to load unserialisable data."""
fname = self._path_for("test3") fname = _path_for("test3")
with open(fname, "w") as fh: with open(fname, "w") as fh:
fh.write(TEST_BAD_YAML) fh.write(TEST_BAD_YAML)
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):

View file

@ -1,5 +1,5 @@
"""Test the unit system helper.""" """Test the unit system helper."""
import unittest import pytest
from homeassistant.util.unit_system import ( from homeassistant.util.unit_system import (
UnitSystem, UnitSystem,
@ -19,16 +19,11 @@ from homeassistant.const import (
TEMPERATURE, TEMPERATURE,
VOLUME VOLUME
) )
import pytest
SYSTEM_NAME = 'TEST' SYSTEM_NAME = 'TEST'
INVALID_UNIT = 'INVALID' INVALID_UNIT = 'INVALID'
class TestUnitSystem(unittest.TestCase): def test_invalid_units():
"""Test the unit system helper."""
def test_invalid_units(self):
"""Test errors are raised when invalid units are passed in.""" """Test errors are raised when invalid units are passed in."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
UnitSystem(SYSTEM_NAME, INVALID_UNIT, LENGTH_METERS, VOLUME_LITERS, UnitSystem(SYSTEM_NAME, INVALID_UNIT, LENGTH_METERS, VOLUME_LITERS,
@ -50,7 +45,8 @@ class TestUnitSystem(unittest.TestCase):
UnitSystem(SYSTEM_NAME, TEMP_CELSIUS, LENGTH_METERS, VOLUME_LITERS, UnitSystem(SYSTEM_NAME, TEMP_CELSIUS, LENGTH_METERS, VOLUME_LITERS,
MASS_GRAMS, INVALID_UNIT) MASS_GRAMS, INVALID_UNIT)
def test_invalid_value(self):
def test_invalid_value():
"""Test no conversion happens if value is non-numeric.""" """Test no conversion happens if value is non-numeric."""
with pytest.raises(TypeError): with pytest.raises(TypeError):
METRIC_SYSTEM.length('25a', LENGTH_KILOMETERS) METRIC_SYSTEM.length('25a', LENGTH_KILOMETERS)
@ -61,7 +57,8 @@ class TestUnitSystem(unittest.TestCase):
with pytest.raises(TypeError): with pytest.raises(TypeError):
METRIC_SYSTEM.pressure('50Pa', PRESSURE_PA) METRIC_SYSTEM.pressure('50Pa', PRESSURE_PA)
def test_as_dict(self):
def test_as_dict():
"""Test that the as_dict() method returns the expected dictionary.""" """Test that the as_dict() method returns the expected dictionary."""
expected = { expected = {
LENGTH: LENGTH_KILOMETERS, LENGTH: LENGTH_KILOMETERS,
@ -73,79 +70,78 @@ class TestUnitSystem(unittest.TestCase):
assert expected == METRIC_SYSTEM.as_dict() assert expected == METRIC_SYSTEM.as_dict()
def test_temperature_same_unit(self):
"""Test no conversion happens if to unit is same as from unit."""
assert 5 == \
METRIC_SYSTEM.temperature(5,
METRIC_SYSTEM.temperature_unit)
def test_temperature_unknown_unit(self): def test_temperature_same_unit():
"""Test no conversion happens if to unit is same as from unit."""
assert METRIC_SYSTEM.temperature(5, METRIC_SYSTEM.temperature_unit) == 5
def test_temperature_unknown_unit():
"""Test no conversion happens if unknown unit.""" """Test no conversion happens if unknown unit."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
METRIC_SYSTEM.temperature(5, 'K') METRIC_SYSTEM.temperature(5, 'K')
def test_temperature_to_metric(self):
def test_temperature_to_metric():
"""Test temperature conversion to metric system.""" """Test temperature conversion to metric system."""
assert 25 == \ assert METRIC_SYSTEM.temperature(25, METRIC_SYSTEM.temperature_unit) == 25
METRIC_SYSTEM.temperature(25, METRIC_SYSTEM.temperature_unit) assert round(METRIC_SYSTEM.temperature(
assert 26.7 == \ 80, IMPERIAL_SYSTEM.temperature_unit), 1) == 26.7
round(METRIC_SYSTEM.temperature(
80, IMPERIAL_SYSTEM.temperature_unit), 1)
def test_temperature_to_imperial(self):
def test_temperature_to_imperial():
"""Test temperature conversion to imperial system.""" """Test temperature conversion to imperial system."""
assert 77 == \ assert IMPERIAL_SYSTEM.temperature(
IMPERIAL_SYSTEM.temperature(77, IMPERIAL_SYSTEM.temperature_unit) 77, IMPERIAL_SYSTEM.temperature_unit) == 77
assert 77 == \ assert IMPERIAL_SYSTEM.temperature(
IMPERIAL_SYSTEM.temperature(25, METRIC_SYSTEM.temperature_unit) 25, METRIC_SYSTEM.temperature_unit) == 77
def test_length_unknown_unit(self):
def test_length_unknown_unit():
"""Test length conversion with unknown from unit.""" """Test length conversion with unknown from unit."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
METRIC_SYSTEM.length(5, 'fr') METRIC_SYSTEM.length(5, 'fr')
def test_length_to_metric(self):
def test_length_to_metric():
"""Test length conversion to metric system.""" """Test length conversion to metric system."""
assert 100 == \ assert METRIC_SYSTEM.length(100, METRIC_SYSTEM.length_unit) == 100
METRIC_SYSTEM.length(100, METRIC_SYSTEM.length_unit) assert METRIC_SYSTEM.length(5, IMPERIAL_SYSTEM.length_unit) == 8.04672
assert 8.04672 == \
METRIC_SYSTEM.length(5, IMPERIAL_SYSTEM.length_unit)
def test_length_to_imperial(self):
def test_length_to_imperial():
"""Test length conversion to imperial system.""" """Test length conversion to imperial system."""
assert 100 == \ assert IMPERIAL_SYSTEM.length(100, IMPERIAL_SYSTEM.length_unit) == 100
IMPERIAL_SYSTEM.length(100, assert IMPERIAL_SYSTEM.length(5, METRIC_SYSTEM.length_unit) == 3.106855
IMPERIAL_SYSTEM.length_unit)
assert 3.106855 == \
IMPERIAL_SYSTEM.length(5, METRIC_SYSTEM.length_unit)
def test_pressure_same_unit(self):
def test_pressure_same_unit():
"""Test no conversion happens if to unit is same as from unit.""" """Test no conversion happens if to unit is same as from unit."""
assert 5 == \ assert METRIC_SYSTEM.pressure(5, METRIC_SYSTEM.pressure_unit) == 5
METRIC_SYSTEM.pressure(5, METRIC_SYSTEM.pressure_unit)
def test_pressure_unknown_unit(self):
def test_pressure_unknown_unit():
"""Test no conversion happens if unknown unit.""" """Test no conversion happens if unknown unit."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
METRIC_SYSTEM.pressure(5, 'K') METRIC_SYSTEM.pressure(5, 'K')
def test_pressure_to_metric(self):
def test_pressure_to_metric():
"""Test pressure conversion to metric system.""" """Test pressure conversion to metric system."""
assert 25 == \ assert METRIC_SYSTEM.pressure(25, METRIC_SYSTEM.pressure_unit) == 25
METRIC_SYSTEM.pressure(25, METRIC_SYSTEM.pressure_unit) assert METRIC_SYSTEM.pressure(14.7, IMPERIAL_SYSTEM.pressure_unit) == \
self.assertAlmostEqual( pytest.approx(101352.932, abs=1e-1)
METRIC_SYSTEM.pressure(14.7, IMPERIAL_SYSTEM.pressure_unit),
101352.932, places=1)
def test_pressure_to_imperial(self):
def test_pressure_to_imperial():
"""Test pressure conversion to imperial system.""" """Test pressure conversion to imperial system."""
assert 77 == \ assert IMPERIAL_SYSTEM.pressure(77, IMPERIAL_SYSTEM.pressure_unit) == 77
IMPERIAL_SYSTEM.pressure(77, IMPERIAL_SYSTEM.pressure_unit) assert IMPERIAL_SYSTEM.pressure(
self.assertAlmostEqual( 101352.932, METRIC_SYSTEM.pressure_unit) == \
IMPERIAL_SYSTEM.pressure(101352.932, METRIC_SYSTEM.pressure_unit), pytest.approx(14.7, abs=1e-4)
14.7, places=4)
def test_properties(self):
def test_properties():
"""Test the unit properties are returned as expected.""" """Test the unit properties are returned as expected."""
assert LENGTH_KILOMETERS == METRIC_SYSTEM.length_unit assert LENGTH_KILOMETERS == METRIC_SYSTEM.length_unit
assert TEMP_CELSIUS == METRIC_SYSTEM.temperature_unit assert TEMP_CELSIUS == METRIC_SYSTEM.temperature_unit
@ -153,7 +149,8 @@ class TestUnitSystem(unittest.TestCase):
assert VOLUME_LITERS == METRIC_SYSTEM.volume_unit assert VOLUME_LITERS == METRIC_SYSTEM.volume_unit
assert PRESSURE_PA == METRIC_SYSTEM.pressure_unit assert PRESSURE_PA == METRIC_SYSTEM.pressure_unit
def test_is_metric(self):
def test_is_metric():
"""Test the is metric flag.""" """Test the is metric flag."""
assert METRIC_SYSTEM.is_metric assert METRIC_SYSTEM.is_metric
assert not IMPERIAL_SYSTEM.is_metric assert not IMPERIAL_SYSTEM.is_metric

View file

@ -1,29 +1,23 @@
"""Test homeassistant volume utility functions.""" """Test homeassistant volume utility functions."""
import unittest import pytest
import homeassistant.util.volume as volume_util import homeassistant.util.volume as volume_util
from homeassistant.const import (VOLUME_LITERS, VOLUME_MILLILITERS, from homeassistant.const import (VOLUME_LITERS, VOLUME_MILLILITERS,
VOLUME_GALLONS, VOLUME_FLUID_OUNCE) VOLUME_GALLONS, VOLUME_FLUID_OUNCE)
import pytest
INVALID_SYMBOL = 'bob' INVALID_SYMBOL = 'bob'
VALID_SYMBOL = VOLUME_LITERS VALID_SYMBOL = VOLUME_LITERS
class TestVolumeUtil(unittest.TestCase): def test_convert_same_unit():
"""Test the volume utility functions."""
def test_convert_same_unit(self):
"""Test conversion from any unit to same unit.""" """Test conversion from any unit to same unit."""
assert 2 == volume_util.convert(2, VOLUME_LITERS, VOLUME_LITERS) assert volume_util.convert(2, VOLUME_LITERS, VOLUME_LITERS) == 2
assert 3 == volume_util.convert(3, VOLUME_MILLILITERS, assert volume_util.convert(3, VOLUME_MILLILITERS, VOLUME_MILLILITERS) == 3
VOLUME_MILLILITERS) assert volume_util.convert(4, VOLUME_GALLONS, VOLUME_GALLONS) == 4
assert 4 == volume_util.convert(4, VOLUME_GALLONS, assert volume_util.convert(5, VOLUME_FLUID_OUNCE, VOLUME_FLUID_OUNCE) == 5
VOLUME_GALLONS)
assert 5 == volume_util.convert(5, VOLUME_FLUID_OUNCE,
VOLUME_FLUID_OUNCE)
def test_convert_invalid_unit(self):
def test_convert_invalid_unit():
"""Test exception is thrown for invalid units.""" """Test exception is thrown for invalid units."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
volume_util.convert(5, INVALID_SYMBOL, VALID_SYMBOL) volume_util.convert(5, INVALID_SYMBOL, VALID_SYMBOL)
@ -31,18 +25,20 @@ class TestVolumeUtil(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
volume_util.convert(5, VALID_SYMBOL, INVALID_SYMBOL) volume_util.convert(5, VALID_SYMBOL, INVALID_SYMBOL)
def test_convert_nonnumeric_value(self):
def test_convert_nonnumeric_value():
"""Test exception is thrown for nonnumeric type.""" """Test exception is thrown for nonnumeric type."""
with pytest.raises(TypeError): with pytest.raises(TypeError):
volume_util.convert('a', VOLUME_GALLONS, VOLUME_LITERS) volume_util.convert('a', VOLUME_GALLONS, VOLUME_LITERS)
def test_convert_from_liters(self):
def test_convert_from_liters():
"""Test conversion from liters to other units.""" """Test conversion from liters to other units."""
liters = 5 liters = 5
assert volume_util.convert(liters, VOLUME_LITERS, assert volume_util.convert(liters, VOLUME_LITERS, VOLUME_GALLONS) == 1.321
VOLUME_GALLONS) == 1.321
def test_convert_from_gallons(self):
def test_convert_from_gallons():
"""Test conversion from gallons to other units.""" """Test conversion from gallons to other units."""
gallons = 5 gallons = 5
assert volume_util.convert(gallons, VOLUME_GALLONS, assert volume_util.convert(gallons, VOLUME_GALLONS,

View file

@ -21,40 +21,39 @@ def mock_credstash():
yield mock_credstash yield mock_credstash
class TestYaml(unittest.TestCase): def test_simple_list():
"""Test util.yaml loader."""
# pylint: disable=no-self-use, invalid-name
def test_simple_list(self):
"""Test simple list.""" """Test simple list."""
conf = "config:\n - simple\n - list" conf = "config:\n - simple\n - list"
with io.StringIO(conf) as file: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc['config'] == ["simple", "list"] assert doc['config'] == ["simple", "list"]
def test_simple_dict(self):
def test_simple_dict():
"""Test simple dict.""" """Test simple dict."""
conf = "key: value" conf = "key: value"
with io.StringIO(conf) as file: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc['key'] == 'value' assert doc['key'] == 'value'
def test_unhashable_key(self):
def test_unhashable_key():
"""Test an unhasable key.""" """Test an unhasable key."""
files = {YAML_CONFIG_FILE: 'message:\n {{ states.state }}'} files = {YAML_CONFIG_FILE: 'message:\n {{ states.state }}'}
with pytest.raises(HomeAssistantError), \ with pytest.raises(HomeAssistantError), \
patch_yaml_files(files): patch_yaml_files(files):
load_yaml_config_file(YAML_CONFIG_FILE) load_yaml_config_file(YAML_CONFIG_FILE)
def test_no_key(self):
def test_no_key():
"""Test item without a key.""" """Test item without a key."""
files = {YAML_CONFIG_FILE: 'a: a\nnokeyhere'} files = {YAML_CONFIG_FILE: 'a: a\nnokeyhere'}
with pytest.raises(HomeAssistantError), \ with pytest.raises(HomeAssistantError), \
patch_yaml_files(files): patch_yaml_files(files):
yaml.load_yaml(YAML_CONFIG_FILE) yaml.load_yaml(YAML_CONFIG_FILE)
def test_environment_variable(self):
def test_environment_variable():
"""Test config file with environment variable.""" """Test config file with environment variable."""
os.environ["PASSWORD"] = "secret_password" os.environ["PASSWORD"] = "secret_password"
conf = "password: !env_var PASSWORD" conf = "password: !env_var PASSWORD"
@ -63,21 +62,24 @@ class TestYaml(unittest.TestCase):
assert doc['password'] == "secret_password" assert doc['password'] == "secret_password"
del os.environ["PASSWORD"] del os.environ["PASSWORD"]
def test_environment_variable_default(self):
def test_environment_variable_default():
"""Test config file with default value for environment variable.""" """Test config file with default value for environment variable."""
conf = "password: !env_var PASSWORD secret_password" conf = "password: !env_var PASSWORD secret_password"
with io.StringIO(conf) as file: with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc['password'] == "secret_password" assert doc['password'] == "secret_password"
def test_invalid_environment_variable(self):
def test_invalid_environment_variable():
"""Test config file with no environment variable sat.""" """Test config file with no environment variable sat."""
conf = "password: !env_var PASSWORD" conf = "password: !env_var PASSWORD"
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
with io.StringIO(conf) as file: with io.StringIO(conf) as file:
yaml.yaml.safe_load(file) yaml.yaml.safe_load(file)
def test_include_yaml(self):
def test_include_yaml():
"""Test include yaml.""" """Test include yaml."""
with patch_yaml_files({'test.yaml': 'value'}): with patch_yaml_files({'test.yaml': 'value'}):
conf = 'key: !include test.yaml' conf = 'key: !include test.yaml'
@ -91,8 +93,9 @@ class TestYaml(unittest.TestCase):
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc["key"] == {} assert doc["key"] == {}
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_list(self, mock_walk): def test_include_dir_list(mock_walk):
"""Test include dir list yaml.""" """Test include dir list yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', [], ['two.yaml', 'one.yaml']], ['/tmp', [], ['two.yaml', 'one.yaml']],
@ -107,8 +110,9 @@ class TestYaml(unittest.TestCase):
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc["key"] == sorted(["one", "two"]) assert doc["key"] == sorted(["one", "two"])
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_list_recursive(self, mock_walk): def test_include_dir_list_recursive(mock_walk):
"""Test include dir recursive list yaml.""" """Test include dir recursive list yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', ['tmp2', '.ignore', 'ignore'], ['zero.yaml']], ['/tmp', ['tmp2', '.ignore', 'ignore'], ['zero.yaml']],
@ -130,8 +134,9 @@ class TestYaml(unittest.TestCase):
assert '.ignore' not in mock_walk.return_value[0][1] assert '.ignore' not in mock_walk.return_value[0][1]
assert sorted(doc["key"]) == sorted(["zero", "one", "two"]) assert sorted(doc["key"]) == sorted(["zero", "one", "two"])
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_named(self, mock_walk): def test_include_dir_named(mock_walk):
"""Test include dir named yaml.""" """Test include dir named yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', [], ['first.yaml', 'second.yaml', 'secrets.yaml']] ['/tmp', [], ['first.yaml', 'second.yaml', 'secrets.yaml']]
@ -147,8 +152,9 @@ class TestYaml(unittest.TestCase):
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert doc["key"] == correct assert doc["key"] == correct
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_named_recursive(self, mock_walk): def test_include_dir_named_recursive(mock_walk):
"""Test include dir named yaml.""" """Test include dir named yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']], ['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']],
@ -171,8 +177,9 @@ class TestYaml(unittest.TestCase):
assert '.ignore' not in mock_walk.return_value[0][1] assert '.ignore' not in mock_walk.return_value[0][1]
assert doc["key"] == correct assert doc["key"] == correct
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_merge_list(self, mock_walk): def test_include_dir_merge_list(mock_walk):
"""Test include dir merge list yaml.""" """Test include dir merge list yaml."""
mock_walk.return_value = [['/tmp', [], ['first.yaml', 'second.yaml']]] mock_walk.return_value = [['/tmp', [], ['first.yaml', 'second.yaml']]]
@ -185,8 +192,9 @@ class TestYaml(unittest.TestCase):
doc = yaml.yaml.safe_load(file) doc = yaml.yaml.safe_load(file)
assert sorted(doc["key"]) == sorted(["one", "two", "three"]) assert sorted(doc["key"]) == sorted(["one", "two", "three"])
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_merge_list_recursive(self, mock_walk): def test_include_dir_merge_list_recursive(mock_walk):
"""Test include dir merge list yaml.""" """Test include dir merge list yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']], ['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']],
@ -209,8 +217,9 @@ class TestYaml(unittest.TestCase):
assert sorted(doc["key"]) == sorted(["one", "two", assert sorted(doc["key"]) == sorted(["one", "two",
"three", "four"]) "three", "four"])
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_merge_named(self, mock_walk): def test_include_dir_merge_named(mock_walk):
"""Test include dir merge named yaml.""" """Test include dir merge named yaml."""
mock_walk.return_value = [['/tmp', [], ['first.yaml', 'second.yaml']]] mock_walk.return_value = [['/tmp', [], ['first.yaml', 'second.yaml']]]
@ -229,8 +238,9 @@ class TestYaml(unittest.TestCase):
"key3": "three" "key3": "three"
} }
@patch('homeassistant.util.yaml.os.walk') @patch('homeassistant.util.yaml.os.walk')
def test_include_dir_merge_named_recursive(self, mock_walk): def test_include_dir_merge_named_recursive(mock_walk):
"""Test include dir merge named yaml.""" """Test include dir merge named yaml."""
mock_walk.return_value = [ mock_walk.return_value = [
['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']], ['/tmp', ['tmp2', '.ignore', 'ignore'], ['first.yaml']],
@ -257,18 +267,21 @@ class TestYaml(unittest.TestCase):
"key4": "four" "key4": "four"
} }
@patch('homeassistant.util.yaml.open', create=True) @patch('homeassistant.util.yaml.open', create=True)
def test_load_yaml_encoding_error(self, mock_open): def test_load_yaml_encoding_error(mock_open):
"""Test raising a UnicodeDecodeError.""" """Test raising a UnicodeDecodeError."""
mock_open.side_effect = UnicodeDecodeError('', b'', 1, 0, '') mock_open.side_effect = UnicodeDecodeError('', b'', 1, 0, '')
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
yaml.load_yaml('test') yaml.load_yaml('test')
def test_dump(self):
def test_dump():
"""The that the dump method returns empty None values.""" """The that the dump method returns empty None values."""
assert yaml.dump({'a': None, 'b': 'b'}) == 'a:\nb: b\n' assert yaml.dump({'a': None, 'b': 'b'}) == 'a:\nb: b\n'
def test_dump_unicode(self):
def test_dump_unicode():
"""The that the dump method returns empty None values.""" """The that the dump method returns empty None values."""
assert yaml.dump({'a': None, 'b': 'привет'}) == 'a:\nb: привет\n' assert yaml.dump({'a': None, 'b': 'привет'}) == 'a:\nb: привет\n'