Fix DeviceInfo configuration_url validation (#97319)

This commit is contained in:
Franck Nijhof 2023-07-27 18:57:01 +02:00 committed by GitHub
parent b92e7c5ddf
commit 737ac8c600
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 25 deletions

View file

@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from urllib.parse import urlparse
import attr
from yarl import URL
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
@ -48,6 +49,8 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
RUNTIME_ONLY_ATTRS = {"suggested_area"}
CONFIGURATION_URL_SCHEMES = {"http", "https", "homeassistant"}
class DeviceEntryDisabler(StrEnum):
"""What disabled a device entry."""
@ -168,28 +171,36 @@ def _validate_device_info(
),
)
if (config_url := device_info.get("configuration_url")) is not None:
if type(config_url) is not str or urlparse(config_url).scheme not in [
"http",
"https",
"homeassistant",
]:
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
device_info,
f"invalid configuration_url '{config_url}'",
)
return device_info_type
def _validate_configuration_url(value: Any) -> str | None:
"""Validate and convert configuration_url."""
if value is None:
return None
if (
isinstance(value, URL)
and (value.scheme not in CONFIGURATION_URL_SCHEMES or not value.host)
) or (
(parsed_url := urlparse(str(value)))
and (
parsed_url.scheme not in CONFIGURATION_URL_SCHEMES
or not parsed_url.hostname
)
):
raise ValueError(f"invalid configuration_url '{value}'")
return str(value)
@attr.s(slots=True, frozen=True)
class DeviceEntry:
"""Device Registry Entry."""
area_id: str | None = attr.ib(default=None)
config_entries: set[str] = attr.ib(converter=set, factory=set)
configuration_url: str | None = attr.ib(default=None)
configuration_url: str | URL | None = attr.ib(
converter=_validate_configuration_url, default=None
)
connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
disabled_by: DeviceEntryDisabler | None = attr.ib(default=None)
entry_type: DeviceEntryType | None = attr.ib(default=None)
@ -453,7 +464,7 @@ class DeviceRegistry:
self,
*,
config_entry_id: str,
configuration_url: str | None | UndefinedType = UNDEFINED,
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED,
default_manufacturer: str | None | UndefinedType = UNDEFINED,
default_model: str | None | UndefinedType = UNDEFINED,
@ -582,7 +593,7 @@ class DeviceRegistry:
*,
add_config_entry_id: str | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED,
configuration_url: str | None | UndefinedType = UNDEFINED,
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED,
entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED,
hw_version: str | None | UndefinedType = UNDEFINED,

View file

@ -15,6 +15,7 @@ from timeit import default_timer as timer
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict, TypeVar, final
import voluptuous as vol
from yarl import URL
from homeassistant.backports.functools import cached_property
from homeassistant.config import DATA_CUSTOMIZE
@ -177,7 +178,7 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
class DeviceInfo(TypedDict, total=False):
"""Entity device information for device registry."""
configuration_url: str | None
configuration_url: str | URL | None
connections: set[tuple[str, str]]
default_manufacturer: str
default_model: str

View file

@ -1,9 +1,11 @@
"""Tests for the Device Registry."""
from contextlib import nullcontext
import time
from typing import Any
from unittest.mock import patch
import pytest
from yarl import URL
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
@ -171,7 +173,7 @@ async def test_loading_from_storage(
{
"area_id": "12345A",
"config_entries": ["1234"],
"configuration_url": "configuration_url",
"configuration_url": "https://example.com/config",
"connections": [["Zigbee", "01.23.45.67.89"]],
"disabled_by": dr.DeviceEntryDisabler.USER,
"entry_type": dr.DeviceEntryType.SERVICE,
@ -213,7 +215,7 @@ async def test_loading_from_storage(
assert entry == dr.DeviceEntry(
area_id="12345A",
config_entries={"1234"},
configuration_url="configuration_url",
configuration_url="https://example.com/config",
connections={("Zigbee", "01.23.45.67.89")},
disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE,
@ -916,7 +918,7 @@ async def test_update(
updated_entry = device_registry.async_update_device(
entry.id,
area_id="12345A",
configuration_url="configuration_url",
configuration_url="https://example.com/config",
disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE,
hw_version="hw_version",
@ -935,7 +937,7 @@ async def test_update(
assert updated_entry == dr.DeviceEntry(
area_id="12345A",
config_entries={"1234"},
configuration_url="configuration_url",
configuration_url="https://example.com/config",
connections={("mac", "12:34:56:ab:cd:ef")},
disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE,
@ -1670,3 +1672,61 @@ async def test_only_disable_device_if_all_config_entries_are_disabled(
entry1 = device_registry.async_get(entry1.id)
assert not entry1.disabled
@pytest.mark.parametrize(
("configuration_url", "expectation"),
[
("http://localhost", nullcontext()),
("http://localhost:8123", nullcontext()),
("https://example.com", nullcontext()),
("http://localhost/config", nullcontext()),
("http://localhost:8123/config", nullcontext()),
("https://example.com/config", nullcontext()),
("homeassistant://config", nullcontext()),
(URL("http://localhost"), nullcontext()),
(URL("http://localhost:8123"), nullcontext()),
(URL("https://example.com"), nullcontext()),
(URL("http://localhost/config"), nullcontext()),
(URL("http://localhost:8123/config"), nullcontext()),
(URL("https://example.com/config"), nullcontext()),
(URL("homeassistant://config"), nullcontext()),
(None, nullcontext()),
("http://", pytest.raises(ValueError)),
("https://", pytest.raises(ValueError)),
("gopher://localhost", pytest.raises(ValueError)),
("homeassistant://", pytest.raises(ValueError)),
(URL("http://"), pytest.raises(ValueError)),
(URL("https://"), pytest.raises(ValueError)),
(URL("gopher://localhost"), pytest.raises(ValueError)),
(URL("homeassistant://"), pytest.raises(ValueError)),
# Exception implements __str__
(Exception("https://example.com"), nullcontext()),
(Exception("https://"), pytest.raises(ValueError)),
(Exception(), pytest.raises(ValueError)),
],
)
async def test_device_info_configuration_url_validation(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
configuration_url: str | URL | None,
expectation,
) -> None:
"""Test configuration URL of device info is properly validated."""
with expectation:
device_registry.async_get_or_create(
config_entry_id="1234",
identifiers={("something", "1234")},
name="name",
configuration_url=configuration_url,
)
update_device = device_registry.async_get_or_create(
config_entry_id="5678",
identifiers={("something", "5678")},
name="name",
)
with expectation:
device_registry.async_update_device(
update_device.id, configuration_url=configuration_url
)

View file

@ -1857,11 +1857,6 @@ async def test_device_name_defaulting_config_entry(
"name": "bla",
"default_name": "yo",
},
# Invalid configuration URL
{
"identifiers": {("hue", "1234")},
"configuration_url": "foo://192.168.0.100/config",
},
],
)
async def test_device_type_error_checking(