Protect state.as_dict from mutation (#65693)
This commit is contained in:
parent
0d3bbfc9a7
commit
5da923c341
14 changed files with 114 additions and 45 deletions
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from typing import Any, TypeVar, cast
|
from typing import Any, TypeVar, cast, overload
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
@ -11,6 +11,16 @@ from .const import REDACTED
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def async_redact_data(data: Mapping, to_redact: Iterable[Any]) -> dict: # type: ignore
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||||
"""Redact sensitive data in a dict."""
|
"""Redact sensitive data in a dict."""
|
||||||
|
@ -25,7 +35,7 @@ def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||||
for key, value in redacted.items():
|
for key, value in redacted.items():
|
||||||
if key in to_redact:
|
if key in to_redact:
|
||||||
redacted[key] = REDACTED
|
redacted[key] = REDACTED
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, Mapping):
|
||||||
redacted[key] = async_redact_data(value, to_redact)
|
redacted[key] = async_redact_data(value, to_redact)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
redacted[key] = [async_redact_data(item, to_redact) for item in value]
|
redacted[key] = [async_redact_data(item, to_redact) for item in value]
|
||||||
|
|
|
@ -457,7 +457,7 @@ async def _register_service(
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute_service(call: ServiceCall) -> None:
|
async def execute_service(call: ServiceCall) -> None:
|
||||||
await entry_data.client.execute_service(service, call.data) # type: ignore[arg-type]
|
await entry_data.client.execute_service(service, call.data)
|
||||||
|
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN, service_name, execute_service, vol.Schema(schema)
|
DOMAIN, service_name, execute_service, vol.Schema(schema)
|
||||||
|
|
|
@ -2,9 +2,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable, Mapping
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -112,8 +111,6 @@ async def async_reproduce_states(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_attr_equal(
|
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
|
||||||
) -> bool:
|
|
||||||
"""Return true if the given attributes are equal."""
|
"""Return true if the given attributes are equal."""
|
||||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||||
|
|
|
@ -2,9 +2,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable, Mapping
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_OPTION
|
from homeassistant.const import ATTR_ENTITY_ID, ATTR_OPTION
|
||||||
|
@ -80,8 +79,6 @@ async def async_reproduce_states(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_attr_equal(
|
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
|
||||||
) -> bool:
|
|
||||||
"""Return true if the given attributes are equal."""
|
"""Return true if the given attributes are equal."""
|
||||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||||
|
|
|
@ -546,7 +546,7 @@ class KNXModule:
|
||||||
replaced_exposure.device.name,
|
replaced_exposure.device.name,
|
||||||
)
|
)
|
||||||
replaced_exposure.shutdown()
|
replaced_exposure.shutdown()
|
||||||
exposure = create_knx_exposure(self.hass, self.xknx, call.data) # type: ignore[arg-type]
|
exposure = create_knx_exposure(self.hass, self.xknx, call.data)
|
||||||
self.service_exposures[group_address] = exposure
|
self.service_exposures[group_address] = exposure
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Service exposure_register registered exposure for '%s' - %s",
|
"Service exposure_register registered exposure for '%s' - %s",
|
||||||
|
|
|
@ -2,9 +2,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable, Mapping
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any, NamedTuple, cast
|
from typing import Any, NamedTuple, cast
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -213,8 +212,6 @@ async def async_reproduce_states(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_attr_equal(
|
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
|
||||||
) -> bool:
|
|
||||||
"""Return true if the given attributes are equal."""
|
"""Return true if the given attributes are equal."""
|
||||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
"""Support for Renault services."""
|
"""Support for Renault services."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -126,7 +126,7 @@ def setup_services(hass: HomeAssistant) -> None:
|
||||||
result = await proxy.vehicle.set_charge_start()
|
result = await proxy.vehicle.set_charge_start()
|
||||||
LOGGER.debug("Charge start result: %s", result)
|
LOGGER.debug("Charge start result: %s", result)
|
||||||
|
|
||||||
def get_vehicle_proxy(service_call_data: MappingProxyType) -> RenaultVehicleProxy:
|
def get_vehicle_proxy(service_call_data: Mapping) -> RenaultVehicleProxy:
|
||||||
"""Get vehicle from service_call data."""
|
"""Get vehicle from service_call data."""
|
||||||
device_registry = dr.async_get(hass)
|
device_registry = dr.async_get(hass)
|
||||||
device_id = service_call_data[ATTR_VEHICLE]
|
device_id = service_call_data[ATTR_VEHICLE]
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any, Final, cast
|
from typing import Any, Final, cast
|
||||||
|
|
||||||
from aioshelly.block_device import Block
|
from aioshelly.block_device import Block
|
||||||
|
@ -140,7 +140,7 @@ class BlockSleepingClimate(
|
||||||
self.control_result: dict[str, Any] | None = None
|
self.control_result: dict[str, Any] | None = None
|
||||||
self.device_block: Block | None = device_block
|
self.device_block: Block | None = device_block
|
||||||
self.last_state: State | None = None
|
self.last_state: State | None = None
|
||||||
self.last_state_attributes: MappingProxyType[str, Any]
|
self.last_state_attributes: Mapping[str, Any]
|
||||||
self._preset_modes: list[str] = []
|
self._preset_modes: list[str] = []
|
||||||
|
|
||||||
if self.block is not None and self.device_block is not None:
|
if self.block is not None and self.device_block is not None:
|
||||||
|
|
|
@ -24,7 +24,6 @@ import pathlib
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
@ -83,6 +82,7 @@ from .util.async_ import (
|
||||||
run_callback_threadsafe,
|
run_callback_threadsafe,
|
||||||
shutdown_run_callback_threadsafe,
|
shutdown_run_callback_threadsafe,
|
||||||
)
|
)
|
||||||
|
from .util.read_only_dict import ReadOnlyDict
|
||||||
from .util.timeout import TimeoutManager
|
from .util.timeout import TimeoutManager
|
||||||
from .util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitSystem
|
from .util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitSystem
|
||||||
|
|
||||||
|
@ -1049,12 +1049,12 @@ class State:
|
||||||
|
|
||||||
self.entity_id = entity_id.lower()
|
self.entity_id = entity_id.lower()
|
||||||
self.state = state
|
self.state = state
|
||||||
self.attributes = MappingProxyType(attributes or {})
|
self.attributes = ReadOnlyDict(attributes or {})
|
||||||
self.last_updated = last_updated or dt_util.utcnow()
|
self.last_updated = last_updated or dt_util.utcnow()
|
||||||
self.last_changed = last_changed or self.last_updated
|
self.last_changed = last_changed or self.last_updated
|
||||||
self.context = context or Context()
|
self.context = context or Context()
|
||||||
self.domain, self.object_id = split_entity_id(self.entity_id)
|
self.domain, self.object_id = split_entity_id(self.entity_id)
|
||||||
self._as_dict: dict[str, Collection[Any]] | None = None
|
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
@ -1063,7 +1063,7 @@ class State:
|
||||||
"_", " "
|
"_", " "
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Collection[Any]]:
|
def as_dict(self) -> ReadOnlyDict[str, Collection[Any]]:
|
||||||
"""Return a dict representation of the State.
|
"""Return a dict representation of the State.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -1077,14 +1077,16 @@ class State:
|
||||||
last_updated_isoformat = last_changed_isoformat
|
last_updated_isoformat = last_changed_isoformat
|
||||||
else:
|
else:
|
||||||
last_updated_isoformat = self.last_updated.isoformat()
|
last_updated_isoformat = self.last_updated.isoformat()
|
||||||
self._as_dict = {
|
self._as_dict = ReadOnlyDict(
|
||||||
"entity_id": self.entity_id,
|
{
|
||||||
"state": self.state,
|
"entity_id": self.entity_id,
|
||||||
"attributes": dict(self.attributes),
|
"state": self.state,
|
||||||
"last_changed": last_changed_isoformat,
|
"attributes": self.attributes,
|
||||||
"last_updated": last_updated_isoformat,
|
"last_changed": last_changed_isoformat,
|
||||||
"context": self.context.as_dict(),
|
"last_updated": last_updated_isoformat,
|
||||||
}
|
"context": ReadOnlyDict(self.context.as_dict()),
|
||||||
|
}
|
||||||
|
)
|
||||||
return self._as_dict
|
return self._as_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1343,7 +1345,7 @@ class StateMachine:
|
||||||
last_changed = None
|
last_changed = None
|
||||||
else:
|
else:
|
||||||
same_state = old_state.state == new_state and not force_update
|
same_state = old_state.state == new_state and not force_update
|
||||||
same_attr = old_state.attributes == MappingProxyType(attributes)
|
same_attr = old_state.attributes == attributes
|
||||||
last_changed = old_state.last_changed if same_state else None
|
last_changed = old_state.last_changed if same_state else None
|
||||||
|
|
||||||
if same_state and same_attr:
|
if same_state and same_attr:
|
||||||
|
@ -1404,7 +1406,7 @@ class ServiceCall:
|
||||||
"""Initialize a service call."""
|
"""Initialize a service call."""
|
||||||
self.domain = domain.lower()
|
self.domain = domain.lower()
|
||||||
self.service = service.lower()
|
self.service = service.lower()
|
||||||
self.data = MappingProxyType(data or {})
|
self.data = ReadOnlyDict(data or {})
|
||||||
self.context = context or Context()
|
self.context = context or Context()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|
|
@ -2,14 +2,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine, Iterable, KeysView
|
from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
import threading
|
import threading
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
import slugify as unicode_slug
|
import slugify as unicode_slug
|
||||||
|
@ -53,7 +52,7 @@ def slugify(text: str | None, *, separator: str = "_") -> str:
|
||||||
|
|
||||||
def repr_helper(inp: Any) -> str:
|
def repr_helper(inp: Any) -> str:
|
||||||
"""Help creating a more readable string representation of objects."""
|
"""Help creating a more readable string representation of objects."""
|
||||||
if isinstance(inp, (dict, MappingProxyType)):
|
if isinstance(inp, Mapping):
|
||||||
return ", ".join(
|
return ", ".join(
|
||||||
f"{repr_helper(key)}={repr_helper(item)}" for key, item in inp.items()
|
f"{repr_helper(key)}={repr_helper(item)}" for key, item in inp.items()
|
||||||
)
|
)
|
||||||
|
|
23
homeassistant/util/read_only_dict.py
Normal file
23
homeassistant/util/read_only_dict.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
"""Read only dictionary."""
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
def _readonly(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Raise an exception when a read only dict is modified."""
|
||||||
|
raise RuntimeError("Cannot modify ReadOnlyDict")
|
||||||
|
|
||||||
|
|
||||||
|
Key = TypeVar("Key")
|
||||||
|
Value = TypeVar("Value")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadOnlyDict(dict[Key, Value]):
|
||||||
|
"""Read only version of dict that is compatible with dict types."""
|
||||||
|
|
||||||
|
__setitem__ = _readonly
|
||||||
|
__delitem__ = _readonly
|
||||||
|
pop = _readonly
|
||||||
|
popitem = _readonly
|
||||||
|
clear = _readonly
|
||||||
|
update = _readonly
|
||||||
|
setdefault = _readonly
|
|
@ -931,9 +931,12 @@ def mock_restore_cache(hass, states):
|
||||||
last_states = {}
|
last_states = {}
|
||||||
for state in states:
|
for state in states:
|
||||||
restored_state = state.as_dict()
|
restored_state = state.as_dict()
|
||||||
restored_state["attributes"] = json.loads(
|
restored_state = {
|
||||||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
**restored_state,
|
||||||
)
|
"attributes": json.loads(
|
||||||
|
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||||
|
),
|
||||||
|
}
|
||||||
last_states[state.entity_id] = restore_state.StoredState(
|
last_states[state.entity_id] = restore_state.StoredState(
|
||||||
State.from_dict(restored_state), now
|
State.from_dict(restored_state), now
|
||||||
)
|
)
|
||||||
|
|
|
@ -39,6 +39,7 @@ from homeassistant.exceptions import (
|
||||||
ServiceNotFound,
|
ServiceNotFound,
|
||||||
)
|
)
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
from homeassistant.util.read_only_dict import ReadOnlyDict
|
||||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||||
|
|
||||||
from tests.common import async_capture_events, async_mock_service
|
from tests.common import async_capture_events, async_mock_service
|
||||||
|
@ -377,10 +378,14 @@ def test_state_as_dict():
|
||||||
"last_updated": last_time.isoformat(),
|
"last_updated": last_time.isoformat(),
|
||||||
"state": "on",
|
"state": "on",
|
||||||
}
|
}
|
||||||
assert state.as_dict() == expected
|
as_dict_1 = state.as_dict()
|
||||||
|
assert isinstance(as_dict_1, ReadOnlyDict)
|
||||||
|
assert isinstance(as_dict_1["attributes"], ReadOnlyDict)
|
||||||
|
assert isinstance(as_dict_1["context"], ReadOnlyDict)
|
||||||
|
assert as_dict_1 == expected
|
||||||
# 2nd time to verify cache
|
# 2nd time to verify cache
|
||||||
assert state.as_dict() == expected
|
assert state.as_dict() == expected
|
||||||
assert state.as_dict() is state.as_dict()
|
assert state.as_dict() is as_dict_1
|
||||||
|
|
||||||
|
|
||||||
async def test_eventbus_add_remove_listener(hass):
|
async def test_eventbus_add_remove_listener(hass):
|
||||||
|
|
36
tests/util/test_read_only_dict.py
Normal file
36
tests/util/test_read_only_dict.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
"""Test read only dictionary."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.util.read_only_dict import ReadOnlyDict
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_only_dict():
|
||||||
|
"""Test read only dictionary."""
|
||||||
|
data = ReadOnlyDict({"hello": "world"})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data["hello"] = "universe"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data["other_key"] = "universe"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data.pop("hello")
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data.popitem()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data.clear()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data.update({"yo": "yo"})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
data.setdefault("yo", "yo")
|
||||||
|
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert dict(data) == {"hello": "world"}
|
||||||
|
assert json.dumps(data) == json.dumps({"hello": "world"})
|
Loading…
Add table
Add a link
Reference in a new issue