Allow renaming entities in entity registry (#12636)

* Allow renaming entities in entity registry

* Lint
This commit is contained in:
Paulus Schoutsen 2018-02-24 10:53:59 -08:00 committed by GitHub
parent 2821820281
commit 6d431c3fc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 333 additions and 47 deletions

View file

@ -13,7 +13,8 @@ from homeassistant.util.yaml import load_yaml, dump
DOMAIN = 'config' DOMAIN = 'config'
DEPENDENCIES = ['http'] DEPENDENCIES = ['http']
SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script') SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script',
'entity_registry')
ON_DEMAND = ('zwave',) ON_DEMAND = ('zwave',)
FEATURE_FLAGS = ('config_entries',) FEATURE_FLAGS = ('config_entries',)

View file

@ -0,0 +1,55 @@
"""HTTP views to interact with the entity registry."""
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.helpers.entity_registry import async_get_registry
async def async_setup(hass):
"""Enable the Entity Registry views."""
hass.http.register_view(ConfigManagerEntityView)
return True
class ConfigManagerEntityView(HomeAssistantView):
"""View to interact with an entity registry entry."""
url = '/api/config/entity_registry/{entity_id}'
name = 'api:config:entity_registry:entity'
async def get(self, request, entity_id):
"""Get the entity registry settings for an entity."""
hass = request.app['hass']
registry = await async_get_registry(hass)
entry = registry.entities.get(entity_id)
if entry is None:
return self.json_message('Entry not found', 404)
return self.json(_entry_dict(entry))
@RequestDataValidator(vol.Schema({
# If passed in, we update value. Passing None will remove old value.
vol.Optional('name'): vol.Any(str, None),
}))
async def post(self, request, entity_id, data):
"""Update the entity registry settings for an entity."""
hass = request.app['hass']
registry = await async_get_registry(hass)
if entity_id not in registry.entities:
return self.json_message('Entry not found', 404)
entry = registry.async_update_entity(entity_id, **data)
return self.json(_entry_dict(entry))
@callback
def _entry_dict(entry):
"""Helper to convert entry to API format."""
return {
'entity_id': entry.entity_id,
'name': entry.name
}

View file

@ -28,11 +28,11 @@ SUPPORT_DEMO = (SUPPORT_BRIGHTNESS | SUPPORT_COLOR_TEMP | SUPPORT_EFFECT |
def setup_platform(hass, config, add_devices_callback, discovery_info=None): def setup_platform(hass, config, add_devices_callback, discovery_info=None):
"""Set up the demo light platform.""" """Set up the demo light platform."""
add_devices_callback([ add_devices_callback([
DemoLight("Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST, DemoLight(1, "Bed Light", False, True, effect_list=LIGHT_EFFECT_LIST,
effect=LIGHT_EFFECT_LIST[0]), effect=LIGHT_EFFECT_LIST[0]),
DemoLight("Ceiling Lights", True, True, DemoLight(2, "Ceiling Lights", True, True,
LIGHT_COLORS[0], LIGHT_TEMPS[1]), LIGHT_COLORS[0], LIGHT_TEMPS[1]),
DemoLight("Kitchen Lights", True, True, DemoLight(3, "Kitchen Lights", True, True,
LIGHT_COLORS[1], LIGHT_TEMPS[0]) LIGHT_COLORS[1], LIGHT_TEMPS[0])
]) ])
@ -40,10 +40,11 @@ def setup_platform(hass, config, add_devices_callback, discovery_info=None):
class DemoLight(Light): class DemoLight(Light):
"""Representation of a demo light.""" """Representation of a demo light."""
def __init__(self, name, state, available=False, rgb=None, ct=None, def __init__(self, unique_id, name, state, available=False, rgb=None,
brightness=180, xy_color=(.5, .5), white=200, ct=None, brightness=180, xy_color=(.5, .5), white=200,
effect_list=None, effect=None): effect_list=None, effect=None):
"""Initialize the light.""" """Initialize the light."""
self._unique_id = unique_id
self._name = name self._name = name
self._state = state self._state = state
self._rgb = rgb self._rgb = rgb
@ -64,6 +65,11 @@ class DemoLight(Light):
"""Return the name of the light if any.""" """Return the name of the light if any."""
return self._name return self._name
@property
def unique_id(self):
"""Return unique ID for light."""
return self._unique_id
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return availability.""" """Return availability."""

View file

@ -340,6 +340,12 @@ class Entity(object):
else: else:
self.hass.states.async_remove(self.entity_id) self.hass.states.async_remove(self.entity_id)
@callback
def async_registry_updated(self, old, new):
"""Called when the entity registry has been updated."""
self.registry_name = new.name
self.async_schedule_update_ha_state()
def __eq__(self, other): def __eq__(self, other):
"""Return the comparison.""" """Return the comparison."""
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):

View file

@ -10,12 +10,11 @@ from homeassistant.util.async import (
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .event import async_track_time_interval, async_track_point_in_time from .event import async_track_time_interval, async_track_point_in_time
from .entity_registry import EntityRegistry from .entity_registry import async_get_registry
SLOW_SETUP_WARNING = 10 SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60 SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10 PLATFORM_NOT_READY_RETRIES = 10
DATA_REGISTRY = 'entity_registry'
class EntityPlatform(object): class EntityPlatform(object):
@ -156,12 +155,7 @@ class EntityPlatform(object):
hass = self.hass hass = self.hass
component_entities = set(hass.states.async_entity_ids(self.domain)) component_entities = set(hass.states.async_entity_ids(self.domain))
registry = hass.data.get(DATA_REGISTRY) registry = yield from async_get_registry(hass)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
yield from registry.async_ensure_loaded()
tasks = [ tasks = [
self._async_add_entity(entity, update_before_add, self._async_add_entity(entity, update_before_add,
@ -226,6 +220,7 @@ class EntityPlatform(object):
entity.entity_id = entry.entity_id entity.entity_id = entry.entity_id
entity.registry_name = entry.name entity.registry_name = entry.name
entry.add_update_listener(entity)
# We won't generate an entity ID if the platform has already set one # We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID # We will however make sure that platform cannot pick a registered ID

View file

@ -15,17 +15,20 @@ from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
import os import os
import weakref
import attr import attr
from ..core import callback, split_entity_id from ..core import callback, split_entity_id
from ..loader import bind_hass
from ..util import ensure_unique_string, slugify from ..util import ensure_unique_string, slugify
from ..util.yaml import load_yaml, save_yaml from ..util.yaml import load_yaml, save_yaml
PATH_REGISTRY = 'entity_registry.yaml' PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry'
SAVE_DELAY = 10 SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DISABLED_HASS = 'hass' DISABLED_HASS = 'hass'
DISABLED_USER = 'user' DISABLED_USER = 'user'
@ -34,6 +37,8 @@ DISABLED_USER = 'user'
class RegistryEntry: class RegistryEntry:
"""Entity Registry Entry.""" """Entity Registry Entry."""
# pylint: disable=no-member
entity_id = attr.ib(type=str) entity_id = attr.ib(type=str)
unique_id = attr.ib(type=str) unique_id = attr.ib(type=str)
platform = attr.ib(type=str) platform = attr.ib(type=str)
@ -41,17 +46,27 @@ class RegistryEntry:
disabled_by = attr.ib( disabled_by = attr.ib(
type=str, default=None, type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None))) validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
domain = attr.ib(type=str, default=None, init=False, repr=False) update_listeners = attr.ib(type=list, default=attr.Factory(list),
repr=False)
domain = attr.ib(type=str, init=False, repr=False)
def __attrs_post_init__(self): @domain.default
"""Computed properties.""" def _domain_default(self):
object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0]) """Compute domain value."""
return split_entity_id(self.entity_id)[0]
@property @property
def disabled(self): def disabled(self):
"""Return if entry is disabled.""" """Return if entry is disabled."""
return self.disabled_by is not None return self.disabled_by is not None
def add_update_listener(self, listener):
"""Listen for when entry is updated.
Listener: Callback function(old_entry, new_entry)
"""
self.update_listeners.append(weakref.ref(listener))
class EntityRegistry: class EntityRegistry:
"""Class to hold a registry of entities.""" """Class to hold a registry of entities."""
@ -102,6 +117,39 @@ class EntityRegistry:
self.async_schedule_save() self.async_schedule_save()
return entity return entity
@callback
def async_update_entity(self, entity_id, *, name=_UNDEF):
"""Update properties of an entity."""
old = self.entities[entity_id]
changes = {}
if name is not _UNDEF and name != old.name:
changes['name'] = name
if not changes:
return old
new = self.entities[entity_id] = attr.evolve(old, **changes)
to_remove = []
for listener_ref in new.update_listeners:
listener = listener_ref()
if listener is None:
to_remove.append(listener)
else:
try:
listener.async_registry_updated(old, new)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error calling update listener')
for ref in to_remove:
new.update_listeners.remove(ref)
self.async_schedule_save()
return new
@asyncio.coroutine @asyncio.coroutine
def async_ensure_loaded(self): def async_ensure_loaded(self):
"""Load the registry from disk.""" """Load the registry from disk."""
@ -154,7 +202,20 @@ class EntityRegistry:
data[entry.entity_id] = { data[entry.entity_id] = {
'unique_id': entry.unique_id, 'unique_id': entry.unique_id,
'platform': entry.platform, 'platform': entry.platform,
'name': entry.name,
} }
yield from self.hass.async_add_job( yield from self.hass.async_add_job(
save_yaml, self.hass.config.path(PATH_REGISTRY), data) save_yaml, self.hass.config.path(PATH_REGISTRY), data)
@bind_hass
async def async_get_registry(hass) -> EntityRegistry:
"""Return entity registry instance."""
registry = hass.data.get(DATA_REGISTRY)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
await registry.async_ensure_loaded()
return registry

View file

@ -1,5 +1,6 @@
"""Test the helper method for writing tests.""" """Test the helper method for writing tests."""
import asyncio import asyncio
from datetime import timedelta
import functools as ft import functools as ft
import os import os
import sys import sys
@ -298,7 +299,7 @@ def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry.""" """Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass) registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or {} registry.entities = mock_entries or {}
hass.data[entity_platform.DATA_REGISTRY] = registry hass.data[entity_registry.DATA_REGISTRY] = registry
return registry return registry
@ -361,6 +362,32 @@ class MockPlatform(object):
self.async_setup_platform = mock_coro_func() self.async_setup_platform = mock_coro_func()
class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults."""
def __init__(
self, hass,
logger=None,
domain='test_domain',
platform_name='test_platform',
scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=lambda: None
):
"""Initialize a mock entity platform."""
super().__init__(
hass=hass,
logger=logger,
domain=domain,
platform_name=platform_name,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=async_entities_added_callback,
)
class MockToggleDevice(entity.ToggleEntity): class MockToggleDevice(entity.ToggleEntity):
"""Provide a mock toggle device.""" """Provide a mock toggle device."""

View file

@ -0,0 +1,134 @@
"""Test entity_registry API."""
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.components.config import entity_registry
from tests.common import mock_registry, MockEntity, MockEntityPlatform
@pytest.fixture
def client(hass, test_client):
"""Fixture that can interact with the config manager API."""
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
hass.loop.run_until_complete(entity_registry.async_setup(hass))
yield hass.loop.run_until_complete(test_client(hass.http.app))
async def test_get_entity(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.name': RegistryEntry(
entity_id='test_domain.name',
unique_id='1234',
platform='test_platform',
name='Hello World'
),
'test_domain.no_name': RegistryEntry(
entity_id='test_domain.no_name',
unique_id='6789',
platform='test_platform',
),
})
resp = await client.get(
'/api/config/entity_registry/test_domain.name')
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.name',
'name': 'Hello World'
}
resp = await client.get(
'/api/config/entity_registry/test_domain.no_name')
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.no_name',
'name': None
}
async def test_update_entity(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='before update'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'before update'
resp = await client.post(
'/api/config/entity_registry/test_domain.world', json={
'name': 'after update'
})
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.world',
'name': 'after update'
}
state = hass.states.get('test_domain.world')
assert state.name == 'after update'
async def test_update_entity_no_changes(hass, client):
"""Test get entry."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='name of entity'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'name of entity'
resp = await client.post(
'/api/config/entity_registry/test_domain.world', json={
'name': 'name of entity'
})
assert resp.status == 200
data = await resp.json()
assert data == {
'entity_id': 'test_domain.world',
'name': 'name of entity'
}
state = hass.states.get('test_domain.world')
assert state.name == 'name of entity'
async def test_get_nonexisting_entity(client):
"""Test get entry."""
resp = await client.get(
'/api/config/entity_registry/test_domain.non_existing')
assert resp.status == 404
async def test_update_nonexisting_entity(client):
"""Test get entry."""
resp = await client.post(
'/api/config/entity_registry/test_domain.non_existing', json={
'name': 'some name'
})
assert resp.status == 404

View file

@ -15,39 +15,13 @@ import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry, get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity) MockEntity, MockEntityPlatform)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain" DOMAIN = "test_domain"
PLATFORM = 'test_platform' PLATFORM = 'test_platform'
class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults."""
def __init__(
self, hass,
logger=None,
domain=DOMAIN,
platform_name=PLATFORM,
scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=lambda: None
):
"""Initialize a mock entity platform."""
super().__init__(
hass=hass,
logger=logger,
domain=domain,
platform_name=platform_name,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=async_entities_added_callback,
)
class TestHelpersEntityPlatform(unittest.TestCase): class TestHelpersEntityPlatform(unittest.TestCase):
"""Test homeassistant.helpers.entity_component module.""" """Test homeassistant.helpers.entity_component module."""
@ -510,3 +484,30 @@ def test_registry_respect_entity_disabled(hass):
yield from platform.async_add_entities([entity]) yield from platform.async_add_entities([entity])
assert entity.entity_id is None assert entity.entity_id is None
assert hass.states.async_entity_ids() == [] assert hass.states.async_entity_ids() == []
async def test_entity_registry_updates(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='before update'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'before update'
registry.async_update_entity('test_domain.world', name='after update')
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get('test_domain.world')
assert state.name == 'after update'