"""Tests for the EntityPlatform helper."""
import asyncio
import logging
import unittest
from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta

import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.entity_component import (
    EntityComponent, DEFAULT_SCAN_INTERVAL)
from homeassistant.helpers import entity_platform, entity_registry

import homeassistant.util.dt as dt_util

from tests.common import (
    get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
    MockEntity)

_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
PLATFORM = 'test_platform'


def _platform_config(platform_name='platform'):
    """Return a fresh config dict selecting one platform under DOMAIN."""
    return {DOMAIN: {'platform': platform_name}}


@asyncio.coroutine
def _async_setup_single_platform(hass, platform):
    """Set up ``platform`` as test_domain.platform and return its handle.

    The component's private platform mapping is emptied before setup, and
    the last value found in it afterwards is returned.
    """
    loader.set_component('test_domain.platform', platform)
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    component._platforms = {}

    yield from component.async_setup(_platform_config())

    handles = list(component._platforms.values())
    return handles[-1]


def _register_world(hass):
    """Mock the entity registry and register test_domain.world in it."""
    registry = mock_registry(hass)
    registry.async_get_or_create(
        DOMAIN, 'test', '1234', suggested_object_id='world')


class MockEntityPlatform(entity_platform.EntityPlatform):
    """EntityPlatform with test defaults for every argument except hass."""

    def __init__(
            self, hass, logger=None, domain=DOMAIN, platform_name=PLATFORM,
            scan_interval=timedelta(seconds=15), parallel_updates=0,
            entity_namespace=None,
            async_entities_added_callback=lambda: None):
        """Forward every argument to EntityPlatform by keyword."""
        platform_kwargs = {
            'hass': hass,
            'logger': logger,
            'domain': domain,
            'platform_name': platform_name,
            'scan_interval': scan_interval,
            'parallel_updates': parallel_updates,
            'entity_namespace': entity_namespace,
            'async_entities_added_callback': async_entities_added_callback,
        }
        super().__init__(**platform_kwargs)


class TestHelpersEntityPlatform(unittest.TestCase):
    """Synchronous unittest-style tests that add and poll entities."""

    def setUp(self):  # pylint: disable=invalid-name
        """Create the Home Assistant test instance used by each test."""
        self.hass = get_test_home_assistant()

    def tearDown(self):  # pylint: disable=invalid-name
        """Stop the Home Assistant test instance."""
        self.hass.stop()

    def test_polling_only_updates_entities_it_should_poll(self):
        """Test that only entities with should_poll=True get polled."""
        interval = timedelta(seconds=20)
        component = EntityComponent(_LOGGER, DOMAIN, self.hass, interval)

        passive = MockEntity(should_poll=False)
        passive.async_update = Mock()
        polled = MockEntity(should_poll=True)
        polled.async_update = Mock()
        component.add_entities([passive, polled])

        # Discard any calls recorded so far so only the poll below counts.
        for entity in (passive, polled):
            entity.async_update.reset_mock()

        fire_time_changed(self.hass, dt_util.utcnow() + interval)
        self.hass.block_till_done()

        assert not passive.async_update.called
        assert polled.async_update.called

    def test_polling_updates_entities_with_exception(self):
        """Test that one failing update does not stop the other updates."""
        interval = timedelta(seconds=20)
        component = EntityComponent(_LOGGER, DOMAIN, self.hass, interval)
        successes = []
        failures = []

        def update_mock():
            """Record a successful update."""
            successes.append(None)

        def update_mock_err():
            """Record a failing update, then raise."""
            failures.append(None)
            raise AssertionError("Fake error update")

        # The first entity fails on update, the remaining three succeed.
        entities = []
        for update in (update_mock_err,) + (update_mock,) * 3:
            entity = MockEntity(should_poll=True)
            entity.update = update
            entities.append(entity)
        component.add_entities(entities)

        # Only count the updates triggered by the time change below.
        successes.clear()
        failures.clear()

        fire_time_changed(self.hass, dt_util.utcnow() + interval)
        self.hass.block_till_done()

        assert len(successes) == 3
        assert len(failures) == 1

    def test_update_state_adds_entities(self):
        """Test that an entity can be added from another entity's update."""
        component = EntityComponent(_LOGGER, DOMAIN, self.hass)
        late_entity = MockEntity()
        polled = MockEntity(should_poll=True)

        component.add_entities([polled])
        assert len(self.hass.states.entity_ids()) == 1

        def add_late_entity(*_):
            """Add the second entity while the first one is updating."""
            return component.add_entities([late_entity])

        polled.update = add_late_entity

        fire_time_changed(
            self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL)
        self.hass.block_till_done()

        assert len(self.hass.states.entity_ids()) == 2

    def test_update_state_adds_entities_with_update_before_add_true(self):
        """Test that update is called when adding with the flag set."""
        component = EntityComponent(_LOGGER, DOMAIN, self.hass)
        entity = MockEntity()
        entity.update = Mock(spec_set=True)

        component.add_entities([entity], True)
        self.hass.block_till_done()

        assert len(self.hass.states.entity_ids()) == 1
        assert entity.update.called

    def test_update_state_adds_entities_with_update_before_add_false(self):
        """Test that update is not called when adding with the flag unset."""
        component = EntityComponent(_LOGGER, DOMAIN, self.hass)
        entity = MockEntity()
        entity.update = Mock(spec_set=True)

        component.add_entities([entity], False)
        self.hass.block_till_done()

        assert len(self.hass.states.entity_ids()) == 1
        assert not entity.update.called

    @patch('homeassistant.helpers.entity_platform.async_track_time_interval')
    def test_set_scan_interval_via_platform(self, mock_track):
        """Test that a platform's SCAN_INTERVAL reaches the time tracker."""
        platform_interval = timedelta(seconds=30)

        def platform_setup(hass, config, add_devices, discovery_info=None):
            """Add a single polling entity."""
            add_devices([MockEntity(should_poll=True)])

        platform = MockPlatform(platform_setup)
        platform.SCAN_INTERVAL = platform_interval
        loader.set_component('test_domain.platform', platform)

        component = EntityComponent(_LOGGER, DOMAIN, self.hass)
        component.setup(_platform_config())
        self.hass.block_till_done()

        assert mock_track.called
        positional_args = mock_track.call_args[0]
        assert positional_args[2] == platform_interval

    def test_adding_entities_with_generator_and_thread_callback(self):
        """Test passing add_entities a generator that uses a thread method.

        The generator has to be resolved to a list before it is handed to
        an async context.
        """
        component = EntityComponent(_LOGGER, DOMAIN, self.hass)

        def create_entity(_index):
            """Build an entity whose ID comes from generate_entity_id."""
            entity = MockEntity()
            entity.entity_id = generate_entity_id(
                DOMAIN + '.{}', 'Number', hass=self.hass)
            return entity

        # Deliberately a generator expression, not a list.
        pending_entities = (create_entity(index) for index in range(2))
        component.add_entities(pending_entities)


@asyncio.coroutine
def test_platform_warn_slow_setup(hass):
    """Test that setup schedules a slow-setup warning and cancels it."""
    loader.set_component('test_domain.platform', MockPlatform())
    component = EntityComponent(_LOGGER, DOMAIN, hass)

    with patch.object(hass.loop, 'call_later', MagicMock()) as mock_call:
        yield from component.async_setup(_platform_config())

        assert mock_call.called

        # First two positional arguments of the first recorded call.
        first_call_args = mock_call.mock_calls[0][1]
        timeout, logger_method = first_call_args[:2]
        assert timeout == entity_platform.SLOW_SETUP_WARNING
        assert logger_method == _LOGGER.warning

        assert mock_call().cancel.called


@asyncio.coroutine
def test_platform_error_slow_setup(hass, caplog):
    """Don't block startup more than SLOW_SETUP_MAX_WAIT."""
    setup_calls = []

    @asyncio.coroutine
    def setup_platform(*args):
        """Record the call, then sleep past the patched max wait."""
        setup_calls.append(1)
        yield from asyncio.sleep(1, loop=hass.loop)

    with patch.object(entity_platform, 'SLOW_SETUP_MAX_WAIT', 0):
        platform = MockPlatform(async_setup_platform=setup_platform)
        component = EntityComponent(_LOGGER, DOMAIN, hass)
        loader.set_component('test_domain.test_platform', platform)

        yield from component.async_setup(_platform_config('test_platform'))

        assert len(setup_calls) == 1
        assert 'test_domain.test_platform' not in hass.config.components
        assert 'test_platform is taking longer than 0 seconds' in caplog.text


@asyncio.coroutine
def test_updated_state_used_for_entity_id(hass):
    """Test that the entity ID is built from the name set by the update."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)

    class MockEntityNameFetcher(MockEntity):
        """Mock entity that only learns its name when updated."""

        @asyncio.coroutine
        def async_update(self):
            """Assign the friendly name."""
            self._values['name'] = "Living Room"

    named_entity = MockEntityNameFetcher()
    yield from component.async_add_entities([named_entity], True)

    entity_ids = hass.states.async_entity_ids()
    assert len(entity_ids) == 1
    assert entity_ids[0] == "test_domain.living_room"


@asyncio.coroutine
def test_parallel_updates_async_platform(hass):
    """Test that an async platform has parallel_updates unset by default."""
    platform = MockPlatform()

    @asyncio.coroutine
    def mock_update(*args, **kwargs):
        """Do nothing; stands in for async_setup_platform."""
        pass

    platform.async_setup_platform = mock_update

    handle = yield from _async_setup_single_platform(hass, platform)

    assert handle.parallel_updates is None


@asyncio.coroutine
def test_parallel_updates_async_platform_with_constant(hass):
    """Test that PARALLEL_UPDATES on an async platform sets the limit."""
    platform = MockPlatform()

    @asyncio.coroutine
    def mock_update(*args, **kwargs):
        """Do nothing; stands in for async_setup_platform."""
        pass

    platform.async_setup_platform = mock_update
    platform.PARALLEL_UPDATES = 1

    handle = yield from _async_setup_single_platform(hass, platform)

    assert handle.parallel_updates is not None


@asyncio.coroutine
def test_parallel_updates_sync_platform(hass):
    """Test that a sync platform gets parallel_updates set."""
    platform = MockPlatform(setup_platform=lambda *args: None)

    handle = yield from _async_setup_single_platform(hass, platform)

    assert handle.parallel_updates is not None


@asyncio.coroutine
def test_raise_error_on_update(hass):
    """Test that a failing update does not stop other entities updating."""
    updates = []
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    failing = MockEntity(name='test_1')
    working = MockEntity(name='test_2')

    def _raise():
        """Fail the update."""
        raise AssertionError

    def _record():
        """Record that the update ran."""
        return updates.append(1)

    failing.update = _raise
    working.update = _record

    yield from component.async_add_entities([failing, working], True)

    assert len(updates) == 1
    assert 1 in updates


@asyncio.coroutine
def test_async_remove_with_platform(hass):
    """Test that removing an entity also removes its state."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entity = MockEntity(name='test_1')

    yield from component.async_add_entities([entity])
    assert len(hass.states.async_entity_ids()) == 1

    yield from entity.async_remove()
    assert len(hass.states.async_entity_ids()) == 0


@asyncio.coroutine
def test_not_adding_duplicate_entities_with_unique_id(hass):
    """Test that a second entity with the same unique ID is not added."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)

    # The state count must stay at one after each of the two additions.
    for entity_name in ('test1', 'test2'):
        duplicate = MockEntity(
            name=entity_name, unique_id='not_very_unique')
        yield from component.async_add_entities([duplicate])
        assert len(hass.states.async_entity_ids()) == 1


@asyncio.coroutine
def test_using_prescribed_entity_id(hass):
    """Test that a predefined entity ID is used as given."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entity = MockEntity(name='bla', entity_id='hello.world')

    yield from component.async_add_entities([entity])

    assert 'hello.world' in hass.states.async_entity_ids()


@asyncio.coroutine
def test_using_prescribed_entity_id_with_unique_id(hass):
    """Test that a taken predefined entity ID gets a numeric suffix."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    first = MockEntity(entity_id='test_domain.world')
    yield from component.async_add_entities([first])

    second = MockEntity(entity_id='test_domain.world', unique_id='bla')
    yield from component.async_add_entities([second])

    assert 'test_domain.world_2' in hass.states.async_entity_ids()


@asyncio.coroutine
def test_using_prescribed_entity_id_which_is_registered(hass):
    """Test that a predefined entity ID already in the registry is changed."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    _register_world(hass)

    # The requested entity_id is taken, so it will be rewritten.
    entity = MockEntity(entity_id='test_domain.world')
    yield from component.async_add_entities([entity])

    assert 'test_domain.world_2' in hass.states.async_entity_ids()


@asyncio.coroutine
def test_name_which_conflict_with_registered(hass):
    """Test that a name-based entity ID avoids a registered entity ID."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    _register_world(hass)

    entity = MockEntity(name='world')
    yield from component.async_add_entities([entity])

    assert 'test_domain.world_2' in hass.states.async_entity_ids()


@asyncio.coroutine
def test_entity_with_name_and_entity_id_getting_registered(hass):
    """Test that the entity ID wins over the name for a unique-ID entity."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entity = MockEntity(
        unique_id='1234', name='bla', entity_id='test_domain.world')

    yield from component.async_add_entities([entity])

    assert 'test_domain.world' in hass.states.async_entity_ids()


@asyncio.coroutine
def test_overriding_name_from_registry(hass):
    """Test that a name stored in the entity registry overrides the entity's."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    registered = entity_registry.RegistryEntry(
        entity_id='test_domain.world',
        unique_id='1234',
        # Using component.async_add_entities is equal to platform "domain"
        platform='test_domain',
        name='Overridden'
    )
    mock_registry(hass, {'test_domain.world': registered})

    entity = MockEntity(unique_id='1234', name='Device Name')
    yield from component.async_add_entities([entity])

    state = hass.states.get('test_domain.world')
    assert state is not None
    assert state.name == 'Overridden'


@asyncio.coroutine
def test_registry_respect_entity_namespace(hass):
    """Test that the entity namespace is part of the generated entity ID."""
    mock_registry(hass)
    platform = MockEntityPlatform(hass, entity_namespace='ns')
    entity = MockEntity(unique_id='1234', name='Device Name')

    yield from platform.async_add_entities([entity])

    assert entity.entity_id == 'test_domain.ns_device_name'


@asyncio.coroutine
def test_registry_respect_entity_disabled(hass):
    """Test that an entity disabled in the registry is not added."""
    disabled = entity_registry.RegistryEntry(
        entity_id='test_domain.world',
        unique_id='1234',
        # Same platform name that MockEntityPlatform uses by default.
        platform='test_platform',
        disabled_by=entity_registry.DISABLED_USER
    )
    mock_registry(hass, {'test_domain.world': disabled})
    platform = MockEntityPlatform(hass)
    entity = MockEntity(unique_id='1234')

    yield from platform.async_add_entities([entity])

    assert entity.entity_id is None
    assert hass.states.async_entity_ids() == []