From 4b7f85518fafa7a600a8964df9268ae2e065e1ce Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 8 Oct 2018 09:30:40 +0200 Subject: [PATCH] Prevent accidental device reg override (#17136) --- homeassistant/helpers/entity_platform.py | 27 +++++++++------ tests/helpers/test_entity_platform.py | 42 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index f2913e37339..99aa10013ab 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -275,17 +275,24 @@ class EntityPlatform: device_info = entity.device_info if config_entry_id is not None and device_info is not None: + processed_dev_info = { + 'config_entry_id': config_entry_id + } + for key in ( + 'connections', + 'identifiers', + 'manufacturer', + 'model', + 'name', + 'sw_version', + 'via_hub', + ): + if key in device_info: + processed_dev_info[key] = device_info[key] + device = device_registry.async_get_or_create( - config_entry_id=config_entry_id, - connections=device_info.get('connections') or set(), - identifiers=device_info.get('identifiers') or set(), - manufacturer=device_info.get('manufacturer'), - model=device_info.get('model'), - name=device_info.get('name'), - sw_version=device_info.get('sw_version'), - via_hub=device_info.get('via_hub')) - if device: - device_id = device.id + **processed_dev_info) + device_id = device.id else: device_id = None diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 631d446d186..97d6a0f5b98 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -728,3 +728,45 @@ async def test_device_info_called(hass): assert device.name == 'test-name' assert device.sw_version == 'test-sw' assert device.hub_device_id == hub.id + + +async def test_device_info_not_overrides(hass): + """Test device info is forwarded correctly.""" + registry = await hass.helpers.device_registry.async_get_registry() + device = registry.async_get_or_create( + config_entry_id='bla', + connections={('mac', 'abcd')}, + manufacturer='test-manufacturer', + model='test-model' + ) + + assert device.manufacturer == 'test-manufacturer' + assert device.model == 'test-model' + + async def async_setup_entry(hass, config_entry, async_add_entities): + """Mock setup entry method.""" + async_add_entities([ + MockEntity(unique_id='qwer', device_info={ + 'connections': {('mac', 'abcd')}, + }), + ]) + return True + + platform = MockPlatform( + async_setup_entry=async_setup_entry + ) + config_entry = MockConfigEntry(entry_id='super-mock-id') + entity_platform = MockEntityPlatform( + hass, + platform_name=config_entry.domain, + platform=platform + ) + + assert await entity_platform.async_setup_entry(config_entry) + await hass.async_block_till_done() + + device2 = registry.async_get_device(set(), {('mac', 'abcd')}) + assert device2 is not None + assert device.id == device2.id + assert device2.manufacturer == 'test-manufacturer' + assert device2.model == 'test-model'