Fix DeviceInfo configuration_url validation (#97319)

This commit is contained in:
Franck Nijhof 2023-07-27 18:57:01 +02:00 committed by GitHub
parent b92e7c5ddf
commit 737ac8c600
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 25 deletions

View file

@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import attr import attr
from yarl import URL
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
@ -48,6 +49,8 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
RUNTIME_ONLY_ATTRS = {"suggested_area"} RUNTIME_ONLY_ATTRS = {"suggested_area"}
CONFIGURATION_URL_SCHEMES = {"http", "https", "homeassistant"}
class DeviceEntryDisabler(StrEnum): class DeviceEntryDisabler(StrEnum):
"""What disabled a device entry.""" """What disabled a device entry."""
@ -168,28 +171,36 @@ def _validate_device_info(
), ),
) )
if (config_url := device_info.get("configuration_url")) is not None:
if type(config_url) is not str or urlparse(config_url).scheme not in [
"http",
"https",
"homeassistant",
]:
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
device_info,
f"invalid configuration_url '{config_url}'",
)
return device_info_type return device_info_type
def _validate_configuration_url(value: Any) -> str | None:
"""Validate and convert configuration_url."""
if value is None:
return None
if (
isinstance(value, URL)
and (value.scheme not in CONFIGURATION_URL_SCHEMES or not value.host)
) or (
(parsed_url := urlparse(str(value)))
and (
parsed_url.scheme not in CONFIGURATION_URL_SCHEMES
or not parsed_url.hostname
)
):
raise ValueError(f"invalid configuration_url '{value}'")
return str(value)
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class DeviceEntry: class DeviceEntry:
"""Device Registry Entry.""" """Device Registry Entry."""
area_id: str | None = attr.ib(default=None) area_id: str | None = attr.ib(default=None)
config_entries: set[str] = attr.ib(converter=set, factory=set) config_entries: set[str] = attr.ib(converter=set, factory=set)
configuration_url: str | None = attr.ib(default=None) configuration_url: str | URL | None = attr.ib(
converter=_validate_configuration_url, default=None
)
connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set) connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set)
disabled_by: DeviceEntryDisabler | None = attr.ib(default=None) disabled_by: DeviceEntryDisabler | None = attr.ib(default=None)
entry_type: DeviceEntryType | None = attr.ib(default=None) entry_type: DeviceEntryType | None = attr.ib(default=None)
@ -453,7 +464,7 @@ class DeviceRegistry:
self, self,
*, *,
config_entry_id: str, config_entry_id: str,
configuration_url: str | None | UndefinedType = UNDEFINED, configuration_url: str | URL | None | UndefinedType = UNDEFINED,
connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED,
default_manufacturer: str | None | UndefinedType = UNDEFINED, default_manufacturer: str | None | UndefinedType = UNDEFINED,
default_model: str | None | UndefinedType = UNDEFINED, default_model: str | None | UndefinedType = UNDEFINED,
@ -582,7 +593,7 @@ class DeviceRegistry:
*, *,
add_config_entry_id: str | UndefinedType = UNDEFINED, add_config_entry_id: str | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED,
configuration_url: str | None | UndefinedType = UNDEFINED, configuration_url: str | URL | None | UndefinedType = UNDEFINED,
disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED,
entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED,
hw_version: str | None | UndefinedType = UNDEFINED, hw_version: str | None | UndefinedType = UNDEFINED,

View file

@ -15,6 +15,7 @@ from timeit import default_timer as timer
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict, TypeVar, final from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict, TypeVar, final
import voluptuous as vol import voluptuous as vol
from yarl import URL
from homeassistant.backports.functools import cached_property from homeassistant.backports.functools import cached_property
from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config import DATA_CUSTOMIZE
@ -177,7 +178,7 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
class DeviceInfo(TypedDict, total=False): class DeviceInfo(TypedDict, total=False):
"""Entity device information for device registry.""" """Entity device information for device registry."""
configuration_url: str | None configuration_url: str | URL | None
connections: set[tuple[str, str]] connections: set[tuple[str, str]]
default_manufacturer: str default_manufacturer: str
default_model: str default_model: str

View file

@ -1,9 +1,11 @@
"""Tests for the Device Registry.""" """Tests for the Device Registry."""
from contextlib import nullcontext
import time import time
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from yarl import URL
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
@ -171,7 +173,7 @@ async def test_loading_from_storage(
{ {
"area_id": "12345A", "area_id": "12345A",
"config_entries": ["1234"], "config_entries": ["1234"],
"configuration_url": "configuration_url", "configuration_url": "https://example.com/config",
"connections": [["Zigbee", "01.23.45.67.89"]], "connections": [["Zigbee", "01.23.45.67.89"]],
"disabled_by": dr.DeviceEntryDisabler.USER, "disabled_by": dr.DeviceEntryDisabler.USER,
"entry_type": dr.DeviceEntryType.SERVICE, "entry_type": dr.DeviceEntryType.SERVICE,
@ -213,7 +215,7 @@ async def test_loading_from_storage(
assert entry == dr.DeviceEntry( assert entry == dr.DeviceEntry(
area_id="12345A", area_id="12345A",
config_entries={"1234"}, config_entries={"1234"},
configuration_url="configuration_url", configuration_url="https://example.com/config",
connections={("Zigbee", "01.23.45.67.89")}, connections={("Zigbee", "01.23.45.67.89")},
disabled_by=dr.DeviceEntryDisabler.USER, disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
@ -916,7 +918,7 @@ async def test_update(
updated_entry = device_registry.async_update_device( updated_entry = device_registry.async_update_device(
entry.id, entry.id,
area_id="12345A", area_id="12345A",
configuration_url="configuration_url", configuration_url="https://example.com/config",
disabled_by=dr.DeviceEntryDisabler.USER, disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
hw_version="hw_version", hw_version="hw_version",
@ -935,7 +937,7 @@ async def test_update(
assert updated_entry == dr.DeviceEntry( assert updated_entry == dr.DeviceEntry(
area_id="12345A", area_id="12345A",
config_entries={"1234"}, config_entries={"1234"},
configuration_url="configuration_url", configuration_url="https://example.com/config",
connections={("mac", "12:34:56:ab:cd:ef")}, connections={("mac", "12:34:56:ab:cd:ef")},
disabled_by=dr.DeviceEntryDisabler.USER, disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
@ -1670,3 +1672,61 @@ async def test_only_disable_device_if_all_config_entries_are_disabled(
entry1 = device_registry.async_get(entry1.id) entry1 = device_registry.async_get(entry1.id)
assert not entry1.disabled assert not entry1.disabled
@pytest.mark.parametrize(
("configuration_url", "expectation"),
[
("http://localhost", nullcontext()),
("http://localhost:8123", nullcontext()),
("https://example.com", nullcontext()),
("http://localhost/config", nullcontext()),
("http://localhost:8123/config", nullcontext()),
("https://example.com/config", nullcontext()),
("homeassistant://config", nullcontext()),
(URL("http://localhost"), nullcontext()),
(URL("http://localhost:8123"), nullcontext()),
(URL("https://example.com"), nullcontext()),
(URL("http://localhost/config"), nullcontext()),
(URL("http://localhost:8123/config"), nullcontext()),
(URL("https://example.com/config"), nullcontext()),
(URL("homeassistant://config"), nullcontext()),
(None, nullcontext()),
("http://", pytest.raises(ValueError)),
("https://", pytest.raises(ValueError)),
("gopher://localhost", pytest.raises(ValueError)),
("homeassistant://", pytest.raises(ValueError)),
(URL("http://"), pytest.raises(ValueError)),
(URL("https://"), pytest.raises(ValueError)),
(URL("gopher://localhost"), pytest.raises(ValueError)),
(URL("homeassistant://"), pytest.raises(ValueError)),
# Exception implements __str__
(Exception("https://example.com"), nullcontext()),
(Exception("https://"), pytest.raises(ValueError)),
(Exception(), pytest.raises(ValueError)),
],
)
async def test_device_info_configuration_url_validation(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
configuration_url: str | URL | None,
expectation,
) -> None:
"""Test configuration URL of device info is properly validated."""
with expectation:
device_registry.async_get_or_create(
config_entry_id="1234",
identifiers={("something", "1234")},
name="name",
configuration_url=configuration_url,
)
update_device = device_registry.async_get_or_create(
config_entry_id="5678",
identifiers={("something", "5678")},
name="name",
)
with expectation:
device_registry.async_update_device(
update_device.id, configuration_url=configuration_url
)

View file

@ -1857,11 +1857,6 @@ async def test_device_name_defaulting_config_entry(
"name": "bla", "name": "bla",
"default_name": "yo", "default_name": "yo",
}, },
# Invalid configuration URL
{
"identifiers": {("hue", "1234")},
"configuration_url": "foo://192.168.0.100/config",
},
], ],
) )
async def test_device_type_error_checking( async def test_device_type_error_checking(