Differentiate between device info types (#95641)

* Differentiate between device info types

* Update allowed fields

* Update homeassistant/helpers/entity_platform.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Split up message in 2 lines

* Use dict for device info types

* Extract device info function and test error checking

* Simplify parsing device info

* move checks around

* Simplify more

* Move error checking around

* Fix order

* fallback config entry title to domain

* Remove fallback for name to config entry domain

* Ensure mocked configuration URLs are strings

* one more test case

* Apply suggestions from code review

Co-authored-by: Erik Montnemery <erik@montnemery.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Paulus Schoutsen 2023-07-10 09:56:06 -04:00 committed by GitHub
parent af22a90b3a
commit eee8566694
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 133 deletions

View file

@ -30,7 +30,6 @@ from homeassistant.core import (
from homeassistant.exceptions import ( from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
PlatformNotReady, PlatformNotReady,
RequiredParameterMissing,
) )
from homeassistant.generated import languages from homeassistant.generated import languages
from homeassistant.setup import async_start_setup from homeassistant.setup import async_start_setup
@ -43,14 +42,13 @@ from . import (
service, service,
translation, translation,
) )
from .device_registry import DeviceRegistry
from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider
from .event import async_call_later, async_track_time_interval from .event import async_call_later, async_track_time_interval
from .issue_registry import IssueSeverity, async_create_issue from .issue_registry import IssueSeverity, async_create_issue
from .typing import UNDEFINED, ConfigType, DiscoveryInfoType from .typing import UNDEFINED, ConfigType, DiscoveryInfoType
if TYPE_CHECKING: if TYPE_CHECKING:
from .entity import Entity from .entity import DeviceInfo, Entity
SLOW_SETUP_WARNING = 10 SLOW_SETUP_WARNING = 10
@ -62,6 +60,37 @@ PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform" DATA_ENTITY_PLATFORM = "entity_platform"
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
DEVICE_INFO_TYPES = {
# Device info is categorized by finding the first device info type which has all
# the keys of the device info. The link device info type must be kept first
# to make it preferred over primary.
"link": {
"connections",
"identifiers",
},
"primary": {
"configuration_url",
"connections",
"entry_type",
"hw_version",
"identifiers",
"manufacturer",
"model",
"name",
"suggested_area",
"sw_version",
"via_device",
},
"secondary": {
"connections",
"default_manufacturer",
"default_model",
"default_name",
# Used by Fritz
"via_device",
},
}
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@ -497,12 +526,9 @@ class EntityPlatform:
hass = self.hass hass = self.hass
device_registry = dev_reg.async_get(hass)
entity_registry = ent_reg.async_get(hass) entity_registry = ent_reg.async_get(hass)
tasks = [ tasks = [
self._async_add_entity( self._async_add_entity(entity, update_before_add, entity_registry)
entity, update_before_add, entity_registry, device_registry
)
for entity in new_entities for entity in new_entities
] ]
@ -564,7 +590,6 @@ class EntityPlatform:
entity: Entity, entity: Entity,
update_before_add: bool, update_before_add: bool,
entity_registry: EntityRegistry, entity_registry: EntityRegistry,
device_registry: DeviceRegistry,
) -> None: ) -> None:
"""Add an entity to the platform.""" """Add an entity to the platform."""
if entity is None: if entity is None:
@ -620,68 +645,10 @@ class EntityPlatform:
entity.add_to_platform_abort() entity.add_to_platform_abort()
return return
device_info = entity.device_info if self.config_entry and (device_info := entity.device_info):
device_id = None device = self._async_process_device_info(device_info)
device = None else:
device = None
if self.config_entry and device_info is not None:
processed_dev_info: dict[str, str | None] = {}
for key in (
"connections",
"default_manufacturer",
"default_model",
"default_name",
"entry_type",
"identifiers",
"manufacturer",
"model",
"name",
"suggested_area",
"sw_version",
"hw_version",
"via_device",
):
if key in device_info:
processed_dev_info[key] = device_info[
key # type: ignore[literal-required]
]
if (
# device info that is purely meant for linking doesn't need default name
any(
key not in {"identifiers", "connections"}
for key in (processed_dev_info)
)
and "default_name" not in processed_dev_info
and not processed_dev_info.get("name")
):
processed_dev_info["name"] = self.config_entry.title
if "configuration_url" in device_info:
if device_info["configuration_url"] is None:
processed_dev_info["configuration_url"] = None
else:
configuration_url = str(device_info["configuration_url"])
if urlparse(configuration_url).scheme in [
"http",
"https",
"homeassistant",
]:
processed_dev_info["configuration_url"] = configuration_url
else:
_LOGGER.warning(
"Ignoring invalid device configuration_url '%s'",
configuration_url,
)
try:
device = device_registry.async_get_or_create(
config_entry_id=self.config_entry.entry_id,
**processed_dev_info, # type: ignore[arg-type]
)
device_id = device.id
except RequiredParameterMissing:
pass
# An entity may suggest the entity_id by setting entity_id itself # An entity may suggest the entity_id by setting entity_id itself
suggested_entity_id: str | None = entity.entity_id suggested_entity_id: str | None = entity.entity_id
@ -716,7 +683,7 @@ class EntityPlatform:
entity.unique_id, entity.unique_id,
capabilities=entity.capability_attributes, capabilities=entity.capability_attributes,
config_entry=self.config_entry, config_entry=self.config_entry,
device_id=device_id, device_id=device.id if device else None,
disabled_by=disabled_by, disabled_by=disabled_by,
entity_category=entity.entity_category, entity_category=entity.entity_category,
get_initial_options=entity.get_initial_entity_options, get_initial_options=entity.get_initial_entity_options,
@ -806,6 +773,62 @@ class EntityPlatform:
await entity.add_to_platform_finish() await entity.add_to_platform_finish()
@callback
def _async_process_device_info(
self, device_info: DeviceInfo
) -> dev_reg.DeviceEntry | None:
"""Process a device info."""
keys = set(device_info)
# If no keys or not enough info to match up, abort
if len(keys & {"connections", "identifiers"}) == 0:
self.logger.error(
"Ignoring device info without identifiers or connections: %s",
device_info,
)
return None
device_info_type: str | None = None
# Find the first device info type which has all keys in the device info
for possible_type, allowed_keys in DEVICE_INFO_TYPES.items():
if keys <= allowed_keys:
device_info_type = possible_type
break
if device_info_type is None:
self.logger.error(
"Device info for %s needs to either describe a device, "
"link to existing device or provide extra information.",
device_info,
)
return None
if (config_url := device_info.get("configuration_url")) is not None:
if type(config_url) is not str or urlparse(config_url).scheme not in [
"http",
"https",
"homeassistant",
]:
self.logger.error(
"Ignoring device info with invalid configuration_url '%s'",
config_url,
)
return None
assert self.config_entry is not None
if device_info_type == "primary" and not device_info.get("name"):
device_info = {
**device_info, # type: ignore[misc]
"name": self.config_entry.title,
}
return dev_reg.async_get(self.hass).async_get_or_create(
config_entry_id=self.config_entry.entry_id,
**device_info,
)
async def async_reset(self) -> None: async def async_reset(self) -> None:
"""Remove all entities and reset data. """Remove all entities and reset data.

View file

@ -10,4 +10,5 @@ def fritz_fixture() -> Mock:
with patch("homeassistant.components.fritzbox.Fritzhome") as fritz, patch( with patch("homeassistant.components.fritzbox.Fritzhome") as fritz, patch(
"homeassistant.components.fritzbox.config_flow.Fritzhome" "homeassistant.components.fritzbox.config_flow.Fritzhome"
): ):
fritz.return_value.get_prefixed_host.return_value = "http://1.2.3.4"
yield fritz yield fritz

View file

@ -114,6 +114,7 @@ def create_mock_client() -> Mock:
mock_client.instances = [ mock_client.instances = [
{"friendly_name": "Test instance 1", "instance": 0, "running": True} {"friendly_name": "Test instance 1", "instance": 0, "running": True}
] ]
mock_client.remote_url = f"http://{TEST_HOST}:{TEST_PORT_UI}"
return mock_client return mock_client

View file

@ -19,6 +19,7 @@ def api_fixture(get_sensors_response):
"""Define a fixture to return a mocked aiopurple API object.""" """Define a fixture to return a mocked aiopurple API object."""
return Mock( return Mock(
async_check_api_key=AsyncMock(), async_check_api_key=AsyncMock(),
get_map_url=Mock(return_value="http://example.com"),
sensors=Mock( sensors=Mock(
async_get_nearby_sensors=AsyncMock( async_get_nearby_sensors=AsyncMock(
return_value=[ return_value=[

View file

@ -1169,57 +1169,6 @@ async def test_device_info_not_overrides(hass: HomeAssistant) -> None:
assert device2.model == "test-model" assert device2.model == "test-model"
async def test_device_info_invalid_url(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test device info is forwarded correctly."""
registry = dr.async_get(hass)
registry.async_get_or_create(
config_entry_id="123",
connections=set(),
identifiers={("hue", "via-id")},
manufacturer="manufacturer",
model="via",
)
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities(
[
# Valid device info, but invalid url
MockEntity(
unique_id="qwer",
device_info={
"identifiers": {("hue", "1234")},
"configuration_url": "foo://192.168.0.100/config",
},
),
]
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(entry_id="super-mock-id")
entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
assert len(hass.states.async_entity_ids()) == 1
device = registry.async_get_device({("hue", "1234")})
assert device is not None
assert device.identifiers == {("hue", "1234")}
assert device.configuration_url is None
assert (
"Ignoring invalid device configuration_url 'foo://192.168.0.100/config'"
in caplog.text
)
async def test_device_info_homeassistant_url( async def test_device_info_homeassistant_url(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
@ -1838,28 +1787,85 @@ async def test_translated_device_class_name_influences_entity_id(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("entity_device_name", "entity_device_default_name", "expected_device_name"), (
"config_entry_title",
"entity_device_name",
"entity_device_default_name",
"expected_device_name",
),
[ [
(None, None, "Mock Config Entry Title"), ("Mock Config Entry Title", None, None, "Mock Config Entry Title"),
("", None, "Mock Config Entry Title"), ("Mock Config Entry Title", "", None, "Mock Config Entry Title"),
(None, "Hello", "Hello"), ("Mock Config Entry Title", None, "Hello", "Hello"),
("Mock Device Name", None, "Mock Device Name"), ("Mock Config Entry Title", "Mock Device Name", None, "Mock Device Name"),
], ],
) )
async def test_device_name_defaulting_config_entry( async def test_device_name_defaulting_config_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry_title: str,
entity_device_name: str, entity_device_name: str,
entity_device_default_name: str, entity_device_default_name: str,
expected_device_name: str, expected_device_name: str,
) -> None: ) -> None:
"""Test setting the device name based on input info.""" """Test setting the device name based on input info."""
device_info = { device_info = {
"identifiers": {("hue", "1234")}, "connections": {(dr.CONNECTION_NETWORK_MAC, "1234")},
"name": entity_device_name,
} }
if entity_device_default_name: if entity_device_default_name:
device_info["default_name"] = entity_device_default_name device_info["default_name"] = entity_device_default_name
else:
device_info["name"] = entity_device_name
class DeviceNameEntity(Entity):
_attr_unique_id = "qwer"
_attr_device_info = device_info
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([DeviceNameEntity()])
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(title=config_entry_title, entry_id="super-mock-id")
entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device(set(), {(dr.CONNECTION_NETWORK_MAC, "1234")})
assert device is not None
assert device.name == expected_device_name
@pytest.mark.parametrize(
("device_info"),
[
# No identifiers
{},
{"name": "bla"},
{"default_name": "bla"},
# Match multiple types
{
"identifiers": {("hue", "1234")},
"name": "bla",
"default_name": "yo",
},
# Invalid configuration URL
{
"identifiers": {("hue", "1234")},
"configuration_url": "foo://192.168.0.100/config",
},
],
)
async def test_device_type_error_checking(
hass: HomeAssistant,
device_info: dict,
) -> None:
"""Test catching invalid device info."""
class DeviceNameEntity(Entity): class DeviceNameEntity(Entity):
_attr_unique_id = "qwer" _attr_unique_id = "qwer"
@ -1879,9 +1885,9 @@ async def test_device_name_defaulting_config_entry(
) )
assert await entity_platform.async_setup_entry(config_entry) assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device({("hue", "1234")}) assert len(dev_reg.devices) == 0
assert device is not None # Entity should still be registered
assert device.name == expected_device_name ent_reg = er.async_get(hass)
assert ent_reg.async_get("test_domain.test_qwer") is not None