Prevent accidental device reg override (#17136)
This commit is contained in:
parent
59d78b060f
commit
4b7f85518f
2 changed files with 59 additions and 10 deletions
|
@ -275,17 +275,24 @@ class EntityPlatform:
|
||||||
device_info = entity.device_info
|
device_info = entity.device_info
|
||||||
|
|
||||||
if config_entry_id is not None and device_info is not None:
|
if config_entry_id is not None and device_info is not None:
|
||||||
|
processed_dev_info = {
|
||||||
|
'config_entry_id': config_entry_id
|
||||||
|
}
|
||||||
|
for key in (
|
||||||
|
'connections',
|
||||||
|
'identifiers',
|
||||||
|
'manufacturer',
|
||||||
|
'model',
|
||||||
|
'name',
|
||||||
|
'sw_version',
|
||||||
|
'via_hub',
|
||||||
|
):
|
||||||
|
if key in device_info:
|
||||||
|
processed_dev_info[key] = device_info[key]
|
||||||
|
|
||||||
device = device_registry.async_get_or_create(
|
device = device_registry.async_get_or_create(
|
||||||
config_entry_id=config_entry_id,
|
**processed_dev_info)
|
||||||
connections=device_info.get('connections') or set(),
|
device_id = device.id
|
||||||
identifiers=device_info.get('identifiers') or set(),
|
|
||||||
manufacturer=device_info.get('manufacturer'),
|
|
||||||
model=device_info.get('model'),
|
|
||||||
name=device_info.get('name'),
|
|
||||||
sw_version=device_info.get('sw_version'),
|
|
||||||
via_hub=device_info.get('via_hub'))
|
|
||||||
if device:
|
|
||||||
device_id = device.id
|
|
||||||
else:
|
else:
|
||||||
device_id = None
|
device_id = None
|
||||||
|
|
||||||
|
|
|
@ -728,3 +728,45 @@ async def test_device_info_called(hass):
|
||||||
assert device.name == 'test-name'
|
assert device.name == 'test-name'
|
||||||
assert device.sw_version == 'test-sw'
|
assert device.sw_version == 'test-sw'
|
||||||
assert device.hub_device_id == hub.id
|
assert device.hub_device_id == hub.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_info_not_overrides(hass):
|
||||||
|
"""Test device info is forwarded correctly."""
|
||||||
|
registry = await hass.helpers.device_registry.async_get_registry()
|
||||||
|
device = registry.async_get_or_create(
|
||||||
|
config_entry_id='bla',
|
||||||
|
connections={('mac', 'abcd')},
|
||||||
|
manufacturer='test-manufacturer',
|
||||||
|
model='test-model'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert device.manufacturer == 'test-manufacturer'
|
||||||
|
assert device.model == 'test-model'
|
||||||
|
|
||||||
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
|
"""Mock setup entry method."""
|
||||||
|
async_add_entities([
|
||||||
|
MockEntity(unique_id='qwer', device_info={
|
||||||
|
'connections': {('mac', 'abcd')},
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
return True
|
||||||
|
|
||||||
|
platform = MockPlatform(
|
||||||
|
async_setup_entry=async_setup_entry
|
||||||
|
)
|
||||||
|
config_entry = MockConfigEntry(entry_id='super-mock-id')
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass,
|
||||||
|
platform_name=config_entry.domain,
|
||||||
|
platform=platform
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await entity_platform.async_setup_entry(config_entry)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
device2 = registry.async_get_device(set(), {('mac', 'abcd')})
|
||||||
|
assert device2 is not None
|
||||||
|
assert device.id == device2.id
|
||||||
|
assert device2.manufacturer == 'test-manufacturer'
|
||||||
|
assert device2.model == 'test-model'
|
||||||
|
|
Loading…
Add table
Reference in a new issue