Use default encoder when saving storage (#75319)
This commit is contained in:
parent
2eebda63fd
commit
9a27f1437d
3 changed files with 62 additions and 36 deletions
|
@ -49,13 +49,6 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
|
||||||
return {} if default is None else default
|
return {} if default is None else default
|
||||||
|
|
||||||
|
|
||||||
def _orjson_encoder(data: Any) -> str:
|
|
||||||
"""JSON encoder that uses orjson."""
|
|
||||||
return orjson.dumps(
|
|
||||||
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def _orjson_default_encoder(data: Any) -> str:
|
def _orjson_default_encoder(data: Any) -> str:
|
||||||
"""JSON encoder that uses orjson with hass defaults."""
|
"""JSON encoder that uses orjson with hass defaults."""
|
||||||
return orjson.dumps(
|
return orjson.dumps(
|
||||||
|
@ -79,21 +72,17 @@ def save_json(
|
||||||
"""
|
"""
|
||||||
dump: Callable[[Any], Any]
|
dump: Callable[[Any], Any]
|
||||||
try:
|
try:
|
||||||
if encoder:
|
# For backwards compatibility, if they pass in the
|
||||||
# For backwards compatibility, if they pass in the
|
# default json encoder we use _orjson_default_encoder
|
||||||
# default json encoder we use _orjson_default_encoder
|
# which is the orjson equivalent to the default encoder.
|
||||||
# which is the orjson equivalent to the default encoder.
|
if encoder and encoder is not DefaultHASSJSONEncoder:
|
||||||
if encoder is DefaultHASSJSONEncoder:
|
|
||||||
dump = _orjson_default_encoder
|
|
||||||
json_data = _orjson_default_encoder(data)
|
|
||||||
# If they pass a custom encoder that is not the
|
# If they pass a custom encoder that is not the
|
||||||
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
|
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
|
||||||
else:
|
dump = json.dumps
|
||||||
dump = json.dumps
|
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
|
||||||
else:
|
else:
|
||||||
dump = _orjson_encoder
|
dump = _orjson_default_encoder
|
||||||
json_data = _orjson_encoder(data)
|
json_data = _orjson_default_encoder(data)
|
||||||
except TypeError as error:
|
except TypeError as error:
|
||||||
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}"
|
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}"
|
||||||
_LOGGER.error(msg)
|
_LOGGER.error(msg)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import json
|
import json
|
||||||
|
from typing import NamedTuple
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -13,8 +14,9 @@ from homeassistant.const import (
|
||||||
from homeassistant.core import CoreState
|
from homeassistant.core import CoreState
|
||||||
from homeassistant.helpers import storage
|
from homeassistant.helpers import storage
|
||||||
from homeassistant.util import dt
|
from homeassistant.util import dt
|
||||||
|
from homeassistant.util.color import RGBColor
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed
|
from tests.common import async_fire_time_changed, async_test_home_assistant
|
||||||
|
|
||||||
MOCK_VERSION = 1
|
MOCK_VERSION = 1
|
||||||
MOCK_VERSION_2 = 2
|
MOCK_VERSION_2 = 2
|
||||||
|
@ -460,3 +462,47 @@ async def test_changing_delayed_written_data(hass, store, hass_storage):
|
||||||
"key": MOCK_KEY,
|
"key": MOCK_KEY,
|
||||||
"data": {"hello": "world"},
|
"data": {"hello": "world"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_saving_load_round_trip(tmpdir):
|
||||||
|
"""Test saving and loading round trip."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
hass = await async_test_home_assistant(loop)
|
||||||
|
|
||||||
|
hass.config.config_dir = await hass.async_add_executor_job(
|
||||||
|
tmpdir.mkdir, "temp_storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
class NamedTupleSubclass(NamedTuple):
|
||||||
|
"""A NamedTuple subclass."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
nts = NamedTupleSubclass("a")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"named_tuple_subclass": nts,
|
||||||
|
"rgb_color": RGBColor(255, 255, 0),
|
||||||
|
"set": {1, 2, 3},
|
||||||
|
"list": [1, 2, 3],
|
||||||
|
"tuple": (1, 2, 3),
|
||||||
|
"dict_with_int": {1: 1, 2: 2},
|
||||||
|
"dict_with_named_tuple": {1: nts, 2: nts},
|
||||||
|
}
|
||||||
|
|
||||||
|
store = storage.Store(
|
||||||
|
hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
|
||||||
|
)
|
||||||
|
await store.async_save(data)
|
||||||
|
load = await store.async_load()
|
||||||
|
assert load == {
|
||||||
|
"dict_with_int": {"1": 1, "2": 2},
|
||||||
|
"dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
|
||||||
|
"list": [1, 2, 3],
|
||||||
|
"named_tuple_subclass": ["a"],
|
||||||
|
"rgb_color": [255, 255, 0],
|
||||||
|
"set": [1, 2, 3],
|
||||||
|
"tuple": [1, 2, 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
await hass.async_stop(force=True)
|
||||||
|
|
|
@ -12,7 +12,6 @@ import pytest
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
|
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
|
||||||
from homeassistant.helpers.template import TupleWrapper
|
|
||||||
from homeassistant.util.json import (
|
from homeassistant.util.json import (
|
||||||
SerializationError,
|
SerializationError,
|
||||||
find_paths_unserializable_data,
|
find_paths_unserializable_data,
|
||||||
|
@ -83,23 +82,15 @@ def test_overwrite_and_reload(atomic_writes):
|
||||||
|
|
||||||
def test_save_bad_data():
|
def test_save_bad_data():
|
||||||
"""Test error from trying to save unserializable data."""
|
"""Test error from trying to save unserializable data."""
|
||||||
|
|
||||||
|
class CannotSerializeMe:
|
||||||
|
"""Cannot serialize this."""
|
||||||
|
|
||||||
with pytest.raises(SerializationError) as excinfo:
|
with pytest.raises(SerializationError) as excinfo:
|
||||||
save_json("test4", {"hello": set()})
|
save_json("test4", {"hello": CannotSerializeMe()})
|
||||||
|
|
||||||
assert (
|
assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
|
||||||
"Failed to serialize to JSON: test4. Bad data at $.hello=set()(<class 'set'>"
|
excinfo.value
|
||||||
in str(excinfo.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_bad_data_tuple_wrapper():
|
|
||||||
"""Test error from trying to save unserializable data."""
|
|
||||||
with pytest.raises(SerializationError) as excinfo:
|
|
||||||
save_json("test4", {"hello": TupleWrapper(("4", "5"))})
|
|
||||||
|
|
||||||
assert (
|
|
||||||
"Failed to serialize to JSON: test4. Bad data at $.hello=('4', '5')(<class 'homeassistant.helpers.template.TupleWrapper'>"
|
|
||||||
in str(excinfo.value)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue