Add tests for entity component

This commit is contained in:
Paulus Schoutsen 2016-01-30 18:55:52 -08:00
parent 6418634f3a
commit 90e17fc77f
4 changed files with 278 additions and 32 deletions

View file

@ -50,8 +50,6 @@ class Entity(object):
""" ABC for Home Assistant entities. """
# pylint: disable=no-self-use
_hidden = False
# SAFE TO OVERWRITE
# The properties and methods here are safe to overwrite when inherting this
# class. These may be used to customize the behavior of the entity.
@ -103,13 +101,14 @@ class Entity(object):
""" Retrieve latest state. """
pass
entity_id = None
# DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation.
hass = None
entity_id = None
def update_ha_state(self, force_refresh=False):
"""

View file

@ -1,9 +1,4 @@
"""
homeassistant.helpers.entity_component
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Provides helpers for components that manage entities.
"""
"""Provides helpers for components that manage entities."""
from threading import Lock
from homeassistant.bootstrap import prepare_setup_platform
@ -18,14 +13,14 @@ DEFAULT_SCAN_INTERVAL = 15
class EntityComponent(object):
"""Helper class that will help a component manage its entities."""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
"""
Helper class that will help a component manage its entities.
"""
def __init__(self, logger, domain, hass,
scan_interval=DEFAULT_SCAN_INTERVAL,
discovery_platforms=None, group_name=None):
"""Initialize an entity component."""
self.logger = logger
self.hass = hass
@ -44,9 +39,10 @@ class EntityComponent(object):
def setup(self, config):
"""
Sets up a full entity component:
- Loads the platforms from the config
- Will listen for supported discovered platforms
Set up a full entity component.
Loads the platforms from the config and will listen for supported
discovered platforms.
"""
self.config = config
@ -57,13 +53,18 @@ class EntityComponent(object):
self._setup_platform(p_type, p_config)
if self.discovery_platforms:
discovery.listen(self.hass, self.discovery_platforms.keys(),
self._entity_discovered)
discovery.listen(
self.hass, self.discovery_platforms.keys(),
lambda service, info:
self._setup_platform(self.discovery_platforms[service], {},
info))
def add_entities(self, new_entities):
"""
Takes in a list of new entities. For each entity will see if it already
exists. If not, will add it, set it up and push the first state.
Add new entities to this component.
For each entity will see if it already exists. If not, will add it,
set it up and push the first state.
"""
with self.lock:
for entity in new_entities:
@ -101,8 +102,10 @@ class EntityComponent(object):
def extract_from_service(self, service):
"""
Takes a service and extracts all known entities.
Will return all if no entity IDs given in service.
Extract all known entities from a service call.
Will return all entities if no entities specified in call.
Will return an empty list if entities specified but unknown.
"""
with self.lock:
if ATTR_ENTITY_ID not in service.data:
@ -113,7 +116,7 @@ class EntityComponent(object):
if entity_id in self.entities]
def _update_entity_states(self, now):
""" Update the states of all the entities. """
"""Update the states of all the polling entities."""
with self.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
@ -125,16 +128,9 @@ class EntityComponent(object):
for entity in entities:
entity.update_ha_state(True)
def _entity_discovered(self, service, info):
""" Called when a entity is discovered. """
if service not in self.discovery_platforms:
return
self._setup_platform(self.discovery_platforms[service], {}, info)
def _setup_platform(self, platform_type, platform_config,
discovery_info=None):
""" Tries to setup a platform for this component. """
"""Setup a platform for this component."""
platform = prepare_setup_platform(
self.hass, self.config, self.domain, platform_type)

View file

@ -145,11 +145,26 @@ class MockHTTP(object):
class MockModule(object):
""" Provides a fake module. """
def __init__(self, domain, dependencies=[], setup=None):
def __init__(self, domain=None, dependencies=[], setup=None):
self.DOMAIN = domain
self.DEPENDENCIES = dependencies
# Setup a mock setup if none given.
self.setup = lambda hass, config: False if setup is None else setup
if setup is None:
self.setup = lambda hass, config: False
else:
self.setup = setup
class MockPlatform(object):
""" Provides a fake platform. """
def __init__(self, setup_platform=None, dependencies=[]):
self.DEPENDENCIES = dependencies
self._setup_platform = setup_platform
def setup_platform(self, hass, config, add_devices, discovery_info=None):
if self._setup_platform is not None:
self._setup_platform(hass, config, add_devices, discovery_info)
class MockToggleDevice(ToggleEntity):

View file

@ -0,0 +1,236 @@
"""
tests.test_helper_entity_component
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tests the entity component helper.
"""
# pylint: disable=protected-access,too-many-public-methods
from collections import OrderedDict
import logging
import unittest
from unittest.mock import patch, Mock
import homeassistant.core as ha
import homeassistant.loader as loader
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.components import discovery
from tests.common import get_test_home_assistant, MockPlatform, MockModule
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
class EntityTest(Entity):
def __init__(self, **values):
self._values = values
if 'entity_id' in values:
self.entity_id = values['entity_id']
@property
def name(self):
return self._handle('name')
@property
def should_poll(self):
return self._handle('should_poll')
@property
def unique_id(self):
return self._handle('unique_id')
def _handle(self, attr):
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)
class TestHelpersEntityComponent(unittest.TestCase):
""" Tests homeassistant.helpers.entity_component module. """
def setUp(self): # pylint: disable=invalid-name
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name
"""Clean up the test Home Assistant instance."""
self.hass.stop()
def test_setting_up_group(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass,
group_name='everyone')
# No group after setup
assert 0 == len(self.hass.states.entity_ids())
component.add_entities([EntityTest(name='hello')])
# group exists
assert 2 == len(self.hass.states.entity_ids())
assert ['group.everyone'] == self.hass.states.entity_ids('group')
group = self.hass.states.get('group.everyone')
assert ('test_domain.hello',) == group.attributes.get('entity_id')
# group extended
component.add_entities([EntityTest(name='hello2')])
assert 3 == len(self.hass.states.entity_ids())
group = self.hass.states.get('group.everyone')
assert ['test_domain.hello', 'test_domain.hello2'] == \
sorted(group.attributes.get('entity_id'))
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_polling_only_updates_entities_it_should_poll(self, mock_track):
component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20)
no_poll_ent = EntityTest(should_poll=False)
no_poll_ent.update_ha_state = Mock()
poll_ent = EntityTest(should_poll=True)
poll_ent.update_ha_state = Mock()
component.add_entities([no_poll_ent])
assert not mock_track.called
component.add_entities([poll_ent])
assert mock_track.called
assert [0, 20, 40] == list(mock_track.call_args[1].get('second'))
no_poll_ent.update_ha_state.reset_mock()
poll_ent.update_ha_state.reset_mock()
component._update_entity_states(None)
assert not no_poll_ent.update_ha_state.called
assert poll_ent.update_ha_state.called
def test_update_state_adds_entities(self):
"""Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent1 = EntityTest()
ent2 = EntityTest(should_poll=True)
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update_ha_state = lambda *_: component.add_entities([ent1])
component._update_entity_states(None)
assert 2 == len(self.hass.states.entity_ids())
def test_not_adding_duplicate_entities(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert 0 == len(self.hass.states.entity_ids())
component.add_entities([None, EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids())
component.add_entities([EntityTest(unique_id='not_very_unique')])
assert 1 == len(self.hass.states.entity_ids())
def test_not_assigning_entity_id_if_prescribes_one(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert 'hello.world' not in self.hass.states.entity_ids()
component.add_entities([EntityTest(entity_id='hello.world')])
assert 'hello.world' in self.hass.states.entity_ids()
def test_extract_from_service_returns_all_if_no_entity_id(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.extract_from_service(call))
def test_extract_from_service_filter_out_non_existing_entities(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
assert ['test_domain.test_2'] == \
[ent.entity_id for ent in component.extract_from_service(call)]
def test_setup_loads_platforms(self):
component_setup = Mock(return_value=True)
platform_setup = Mock(return_value=None)
loader.set_component(
'test_component',
MockModule('test_component', setup=component_setup))
loader.set_component('test_domain.mod2',
MockPlatform(platform_setup, ['test_component']))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert not component_setup.called
assert not platform_setup.called
component.setup({
DOMAIN: {
'platform': 'mod2',
}
})
assert component_setup.called
assert platform_setup.called
def test_setup_recovers_when_setup_raises(self):
platform1_setup = Mock(side_effect=Exception('Broken'))
platform2_setup = Mock(return_value=None)
loader.set_component('test_domain.mod1', MockPlatform(platform1_setup))
loader.set_component('test_domain.mod2', MockPlatform(platform2_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
assert not platform1_setup.called
assert not platform2_setup.called
component.setup(OrderedDict([
(DOMAIN, {'platform': 'mod1'}),
("{} 2".format(DOMAIN), {'platform': 'non_exist'}),
("{} 3".format(DOMAIN), {'platform': 'mod2'}),
]))
assert platform1_setup.called
assert platform2_setup.called
@patch('homeassistant.helpers.entity_component.EntityComponent'
'._setup_platform')
def test_setup_does_discovery(self, mock_setup):
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, discovery_platforms={
'discovery.test': 'platform_test',
})
component.setup({})
self.hass.bus.fire(discovery.EVENT_PLATFORM_DISCOVERED, {
discovery.ATTR_SERVICE: 'discovery.test',
discovery.ATTR_DISCOVERED: 'discovery_info',
})
self.hass.pool.block_till_done()
assert mock_setup.called
assert ('platform_test', {}, 'discovery_info') == \
mock_setup.call_args[0]