diff --git a/homeassistant/components/binary_sensor/deconz.py b/homeassistant/components/binary_sensor/deconz.py index 9aa0c446f2b..1fb62124407 100644 --- a/homeassistant/components/binary_sensor/deconz.py +++ b/homeassistant/components/binary_sensor/deconz.py @@ -116,15 +116,15 @@ class DeconzBinarySensor(BinarySensorDevice): return attr @property - def device(self): + def device_info(self): """Return a device description for device registry.""" if (self._sensor.uniqueid is None or self._sensor.uniqueid.count(':') != 7): return None serial = self._sensor.uniqueid.split('-', 1)[0] return { - 'connection': [[CONNECTION_ZIGBEE, serial]], - 'identifiers': [[DECONZ_DOMAIN, serial]], + 'connections': {(CONNECTION_ZIGBEE, serial)}, + 'identifiers': {(DECONZ_DOMAIN, serial)}, 'manufacturer': self._sensor.manufacturer, 'model': self._sensor.modelid, 'name': self._sensor.name, diff --git a/homeassistant/components/deconz/__init__.py b/homeassistant/components/deconz/__init__.py index d435e9e3c04..a4edc009ea1 100644 --- a/homeassistant/components/deconz/__init__.py +++ b/homeassistant/components/deconz/__init__.py @@ -123,8 +123,9 @@ async def async_setup_entry(hass, config_entry): device_registry = await \ hass.helpers.device_registry.async_get_registry() device_registry.async_get_or_create( - connection=[[CONNECTION_NETWORK_MAC, deconz.config.mac]], - identifiers=[[DOMAIN, deconz.config.bridgeid]], + config_entry=config_entry.entry_id, + connections={(CONNECTION_NETWORK_MAC, deconz.config.mac)}, + identifiers={(DOMAIN, deconz.config.bridgeid)}, manufacturer='Dresden Elektronik', model=deconz.config.modelid, name=deconz.config.name, sw_version=deconz.config.swversion) diff --git a/homeassistant/components/light/deconz.py b/homeassistant/components/light/deconz.py index 067f1474f96..412cf8693e5 100644 --- a/homeassistant/components/light/deconz.py +++ b/homeassistant/components/light/deconz.py @@ -202,15 +202,15 @@ class DeconzLight(Light): return attributes @property - def device(self): + def device_info(self): """Return a device description for device registry.""" if (self._light.uniqueid is None or self._light.uniqueid.count(':') != 7): return None serial = self._light.uniqueid.split('-', 1)[0] return { - 'connection': [[CONNECTION_ZIGBEE, serial]], - 'identifiers': [[DECONZ_DOMAIN, serial]], + 'connections': {(CONNECTION_ZIGBEE, serial)}, + 'identifiers': {(DECONZ_DOMAIN, serial)}, 'manufacturer': self._light.manufacturer, 'model': self._light.modelid, 'name': self._light.name, diff --git a/homeassistant/components/media_player/roku.py b/homeassistant/components/media_player/roku.py index fa1120db98c..fca7b29d2ec 100644 --- a/homeassistant/components/media_player/roku.py +++ b/homeassistant/components/media_player/roku.py @@ -87,7 +87,7 @@ class RokuDevice(MediaPlayerDevice): self.ip_address = host self.channels = [] self.current_app = None - self.device_info = {} + self._device_info = {} self.update() @@ -96,7 +96,7 @@ class RokuDevice(MediaPlayerDevice): import requests.exceptions try: - self.device_info = self.roku.device_info + self._device_info = self.roku.device_info self.ip_address = self.roku.host self.channels = self.get_source_list() @@ -121,9 +121,9 @@ class RokuDevice(MediaPlayerDevice): @property def name(self): """Return the name of the device.""" - if self.device_info.userdevicename: - return self.device_info.userdevicename - return "Roku {}".format(self.device_info.sernum) + if self._device_info.userdevicename: + return self._device_info.userdevicename + return "Roku {}".format(self._device_info.sernum) @property def state(self): @@ -149,7 +149,7 @@ class RokuDevice(MediaPlayerDevice): @property def unique_id(self): """Return a unique, HASS-friendly identifier for this entity.""" - return self.device_info.sernum + return self._device_info.sernum @property def media_content_type(self): diff --git a/homeassistant/components/media_player/soundtouch.py b/homeassistant/components/media_player/soundtouch.py index f2ac45a996f..489d028aad4 100644 --- a/homeassistant/components/media_player/soundtouch.py +++ b/homeassistant/components/media_player/soundtouch.py @@ -166,6 +166,11 @@ class SoundTouchDevice(MediaPlayerDevice): """Return specific soundtouch configuration.""" return self._config + @property + def device(self): + """Return Soundtouch device.""" + return self._device + def update(self): """Retrieve the latest data.""" self._status = self._device.status() @@ -318,8 +323,8 @@ class SoundTouchDevice(MediaPlayerDevice): _LOGGER.warning("Unable to create zone without slaves") else: _LOGGER.info("Creating zone with master %s", - self._device.config.name) - self._device.create_zone([slave.device for slave in slaves]) + self.device.config.name) + self.device.create_zone([slave.device for slave in slaves]) def remove_zone_slave(self, slaves): """ @@ -336,8 +341,8 @@ class SoundTouchDevice(MediaPlayerDevice): _LOGGER.warning("Unable to find slaves to remove") else: _LOGGER.info("Removing slaves from zone with master %s", - self._device.config.name) - self._device.remove_zone_slave([slave.device for slave in slaves]) + self.device.config.name) + self.device.remove_zone_slave([slave.device for slave in slaves]) def add_zone_slave(self, slaves): """ @@ -352,5 +357,5 @@ class SoundTouchDevice(MediaPlayerDevice): _LOGGER.warning("Unable to find slaves to add") else: _LOGGER.info("Adding slaves to zone with master %s", - self._device.config.name) - self._device.add_zone_slave([slave.device for slave in slaves]) + self.device.config.name) + self.device.add_zone_slave([slave.device for slave in slaves]) diff --git a/homeassistant/components/sensor/deconz.py b/homeassistant/components/sensor/deconz.py index 45c604a74ee..8cb3915dc46 100644 --- a/homeassistant/components/sensor/deconz.py +++ b/homeassistant/components/sensor/deconz.py @@ -136,15 +136,15 @@ class DeconzSensor(Entity): return attr @property - def device(self): + def device_info(self): """Return a device description for device registry.""" if (self._sensor.uniqueid is None or self._sensor.uniqueid.count(':') != 7): return None serial = self._sensor.uniqueid.split('-', 1)[0] return { - 'connection': [[CONNECTION_ZIGBEE, serial]], - 'identifiers': [[DECONZ_DOMAIN, serial]], + 'connections': {(CONNECTION_ZIGBEE, serial)}, + 'identifiers': {(DECONZ_DOMAIN, serial)}, 'manufacturer': self._sensor.manufacturer, 'model': self._sensor.modelid, 'name': self._sensor.name, @@ -211,15 +211,15 @@ class DeconzBattery(Entity): return attr @property - def device(self): + def device_info(self): """Return a device description for device registry.""" if (self._device.uniqueid is None or self._device.uniqueid.count(':') != 7): return None serial = self._device.uniqueid.split('-', 1)[0] return { - 'connection': [[CONNECTION_ZIGBEE, serial]], - 'identifiers': [[DECONZ_DOMAIN, serial]], + 'connections': {(CONNECTION_ZIGBEE, serial)}, + 'identifiers': {(DECONZ_DOMAIN, serial)}, 'manufacturer': self._device.manufacturer, 'model': self._device.modelid, 'name': self._device.name, diff --git a/homeassistant/components/switch/deconz.py b/homeassistant/components/switch/deconz.py index 7d861e4c29c..35dbc3ef782 100644 --- a/homeassistant/components/switch/deconz.py +++ b/homeassistant/components/switch/deconz.py @@ -81,15 +81,15 @@ class DeconzSwitch(SwitchDevice): return False @property - def device(self): + def device_info(self): """Return a device description for device registry.""" if (self._switch.uniqueid is None or self._switch.uniqueid.count(':') != 7): return None serial = self._switch.uniqueid.split('-', 1)[0] return { - 'connection': [[CONNECTION_ZIGBEE, serial]], - 'identifiers': [[DECONZ_DOMAIN, serial]], + 'connections': {(CONNECTION_ZIGBEE, serial)}, + 'identifiers': {(DECONZ_DOMAIN, serial)}, 'manufacturer': self._switch.manufacturer, 'model': self._switch.modelid, 'name': self._switch.name, diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 19a6eaa62dc..31da40134a5 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -23,8 +23,9 @@ CONNECTION_ZIGBEE = 'zigbee' class DeviceEntry: """Device Registry Entry.""" - connection = attr.ib(type=list) - identifiers = attr.ib(type=list) + config_entries = attr.ib(type=set, converter=set) + connections = attr.ib(type=set, converter=set) + identifiers = attr.ib(type=set, converter=set) manufacturer = attr.ib(type=str) model = attr.ib(type=str) name = attr.ib(type=str, default=None) @@ -46,29 +47,36 @@ class DeviceRegistry: """Check if device is registered.""" for device in self.devices: if any(iden in device.identifiers for iden in identifiers) or \ - any(conn in device.connection for conn in connections): + any(conn in device.connections for conn in connections): return device return None @callback - def async_get_or_create(self, *, connection, identifiers, manufacturer, - model, name=None, sw_version=None): + def async_get_or_create(self, *, config_entry, connections, identifiers, + manufacturer, model, name=None, sw_version=None): """Get device. Create if it doesn't exist.""" - device = self.async_get_device(identifiers, connection) + if not identifiers and not connections: + return None + + device = self.async_get_device(identifiers, connections) if device is not None: + if config_entry not in device.config_entries: + device.config_entries.add(config_entry) + self.async_schedule_save() return device device = DeviceEntry( - connection=connection, + config_entries=[config_entry], + connections=connections, identifiers=identifiers, manufacturer=manufacturer, model=model, name=name, sw_version=sw_version ) - self.devices.append(device) + self.async_schedule_save() return device @@ -81,7 +89,16 @@ class DeviceRegistry: self.devices = [] return - self.devices = [DeviceEntry(**device) for device in devices['devices']] + self.devices = [DeviceEntry( + config_entries=device['config_entries'], + connections={tuple(conn) for conn in device['connections']}, + identifiers={tuple(iden) for iden in device['identifiers']}, + manufacturer=device['manufacturer'], + model=device['model'], + name=device['name'], + sw_version=device['sw_version'], + id=device['id'], + ) for device in devices['devices']] @callback def async_schedule_save(self): @@ -95,13 +112,14 @@ class DeviceRegistry: data['devices'] = [ { - 'id': entry.id, - 'connection': entry.connection, - 'identifiers': entry.identifiers, + 'config_entries': list(entry.config_entries), + 'connections': list(entry.connections), + 'identifiers': list(entry.identifiers), 'manufacturer': entry.manufacturer, 'model': entry.model, 'name': entry.name, 'sw_version': entry.sw_version, + 'id': entry.id, } for entry in self.devices ] diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 78806e65ef1..695da5bce9c 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -131,7 +131,7 @@ class Entity: return None @property - def device(self): + def device_info(self): """Return device specific attributes. Implemented by platform classes. diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index ffac68c5f07..083a2946122 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -272,15 +272,16 @@ class EntityPlatform: else: config_entry_id = None - device = entity.device - if device is not None: + device_info = entity.device_info + if config_entry_id is not None and device_info is not None: device = device_registry.async_get_or_create( - connection=device['connection'], - identifiers=device['identifiers'], - manufacturer=device['manufacturer'], - model=device['model'], - name=device.get('name'), - sw_version=device.get('sw_version')) + config_entry=config_entry_id, + connections=device_info.get('connections', []), + identifiers=device_info.get('identifiers', []), + manufacturer=device_info.get('manufacturer'), + model=device_info.get('model'), + name=device_info.get('name'), + sw_version=device_info.get('sw_version')) device_id = device.id else: device_id = None diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index f7792eb5250..b2e73071823 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -26,22 +26,73 @@ def registry(hass): async def test_get_or_create_returns_same_entry(registry): """Make sure we do not duplicate entries.""" entry = registry.async_get_or_create( - connection=[['ethernet', '12:34:56:78:90:AB:CD:EF']], - identifiers=[['bridgeid', '0123']], + config_entry='1234', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry2 = registry.async_get_or_create( - connection=[['ethernet', '11:22:33:44:55:66:77:88']], - identifiers=[['bridgeid', '0123']], + config_entry='1234', + connections={('ethernet', '11:22:33:44:55:66:77:88')}, + identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( - connection=[['ethernet', '12:34:56:78:90:AB:CD:EF']], - identifiers=[['bridgeid', '1234']], + config_entry='1234', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('bridgeid', '1234')}, manufacturer='manufacturer', model='model') assert len(registry.devices) == 1 assert entry is entry2 assert entry is entry3 - assert entry.identifiers == [['bridgeid', '0123']] + assert entry.identifiers == {('bridgeid', '0123')} + + +async def test_requirement_for_identifier_or_connection(registry): + """Make sure we do require some descriptor of device.""" + entry = registry.async_get_or_create( + config_entry='1234', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers=set(), + manufacturer='manufacturer', model='model') + entry2 = registry.async_get_or_create( + config_entry='1234', + connections=set(), + identifiers={('bridgeid', '0123')}, + manufacturer='manufacturer', model='model') + entry3 = registry.async_get_or_create( + config_entry='1234', + connections=set(), + identifiers=set(), + manufacturer='manufacturer', model='model') + + assert len(registry.devices) == 2 + assert entry + assert entry2 + assert entry3 is None + + +async def test_multiple_config_entries(registry): + """Make sure we do not get duplicate entries.""" + entry = registry.async_get_or_create( + config_entry='123', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('bridgeid', '0123')}, + manufacturer='manufacturer', model='model') + entry2 = registry.async_get_or_create( + config_entry='456', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('bridgeid', '0123')}, + manufacturer='manufacturer', model='model') + entry3 = registry.async_get_or_create( + config_entry='123', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('bridgeid', '0123')}, + manufacturer='manufacturer', model='model') + + assert len(registry.devices) == 1 + assert entry is entry2 + assert entry is entry3 + assert entry.config_entries == {'123', '456'} async def test_loading_from_storage(hass, hass_storage): @@ -51,7 +102,10 @@ async def test_loading_from_storage(hass, hass_storage): 'data': { 'devices': [ { - 'connection': [ + 'config_entries': [ + '1234' + ], + 'connections': [ [ 'Zigbee', '01.23.45.67.89' @@ -67,7 +121,7 @@ async def test_loading_from_storage(hass, hass_storage): 'manufacturer': 'manufacturer', 'model': 'model', 'name': 'name', - 'sw_version': 'version' + 'sw_version': 'version', } ] } @@ -76,7 +130,9 @@ async def test_loading_from_storage(hass, hass_storage): registry = await device_registry.async_get_registry(hass) entry = registry.async_get_or_create( - connection=[['Zigbee', '01.23.45.67.89']], - identifiers=[['serial', '12:34:56:78:90:AB:CD:EF']], + config_entry='1234', + connections={('Zigbee', '01.23.45.67.89')}, + identifiers={('serial', '12:34:56:78:90:AB:CD:EF')}, manufacturer='manufacturer', model='model') assert entry.id == 'abcdefghijklm' + assert isinstance(entry.config_entries, set)