Use default encoder when saving storage (#75319)

This commit is contained in:
J. Nick Koston 2022-07-17 07:25:19 -05:00 committed by GitHub
parent 2eebda63fd
commit 9a27f1437d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 36 deletions

View file

@ -49,13 +49,6 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
return {} if default is None else default return {} if default is None else default
def _orjson_encoder(data: Any) -> str:
"""JSON encoder that uses orjson."""
return orjson.dumps(
data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS
).decode("utf-8")
def _orjson_default_encoder(data: Any) -> str: def _orjson_default_encoder(data: Any) -> str:
"""JSON encoder that uses orjson with hass defaults.""" """JSON encoder that uses orjson with hass defaults."""
return orjson.dumps( return orjson.dumps(
@ -79,21 +72,17 @@ def save_json(
""" """
dump: Callable[[Any], Any] dump: Callable[[Any], Any]
try: try:
if encoder: # For backwards compatibility, if they pass in the
# For backwards compatibility, if they pass in the # default json encoder we use _orjson_default_encoder
# default json encoder we use _orjson_default_encoder # which is the orjson equivalent to the default encoder.
# which is the orjson equivalent to the default encoder. if encoder and encoder is not DefaultHASSJSONEncoder:
if encoder is DefaultHASSJSONEncoder:
dump = _orjson_default_encoder
json_data = _orjson_default_encoder(data)
# If they pass a custom encoder that is not the # If they pass a custom encoder that is not the
# DefaultHASSJSONEncoder, we use the slow path of json.dumps # DefaultHASSJSONEncoder, we use the slow path of json.dumps
else: dump = json.dumps
dump = json.dumps json_data = json.dumps(data, indent=2, cls=encoder)
json_data = json.dumps(data, indent=2, cls=encoder)
else: else:
dump = _orjson_encoder dump = _orjson_default_encoder
json_data = _orjson_encoder(data) json_data = _orjson_default_encoder(data)
except TypeError as error: except TypeError as error:
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}" msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data, dump=dump))}"
_LOGGER.error(msg) _LOGGER.error(msg)

View file

@ -2,6 +2,7 @@
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
import json import json
from typing import NamedTuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -13,8 +14,9 @@ from homeassistant.const import (
from homeassistant.core import CoreState from homeassistant.core import CoreState
from homeassistant.helpers import storage from homeassistant.helpers import storage
from homeassistant.util import dt from homeassistant.util import dt
from homeassistant.util.color import RGBColor
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed, async_test_home_assistant
MOCK_VERSION = 1 MOCK_VERSION = 1
MOCK_VERSION_2 = 2 MOCK_VERSION_2 = 2
@ -460,3 +462,47 @@ async def test_changing_delayed_written_data(hass, store, hass_storage):
"key": MOCK_KEY, "key": MOCK_KEY,
"data": {"hello": "world"}, "data": {"hello": "world"},
} }
async def test_saving_load_round_trip(tmpdir):
"""Test saving and loading round trip."""
loop = asyncio.get_running_loop()
hass = await async_test_home_assistant(loop)
hass.config.config_dir = await hass.async_add_executor_job(
tmpdir.mkdir, "temp_storage"
)
class NamedTupleSubclass(NamedTuple):
"""A NamedTuple subclass."""
name: str
nts = NamedTupleSubclass("a")
data = {
"named_tuple_subclass": nts,
"rgb_color": RGBColor(255, 255, 0),
"set": {1, 2, 3},
"list": [1, 2, 3],
"tuple": (1, 2, 3),
"dict_with_int": {1: 1, 2: 2},
"dict_with_named_tuple": {1: nts, 2: nts},
}
store = storage.Store(
hass, MOCK_VERSION_2, MOCK_KEY, minor_version=MOCK_MINOR_VERSION_1
)
await store.async_save(data)
load = await store.async_load()
assert load == {
"dict_with_int": {"1": 1, "2": 2},
"dict_with_named_tuple": {"1": ["a"], "2": ["a"]},
"list": [1, 2, 3],
"named_tuple_subclass": ["a"],
"rgb_color": [255, 255, 0],
"set": [1, 2, 3],
"tuple": [1, 2, 3],
}
await hass.async_stop(force=True)

View file

@ -12,7 +12,6 @@ import pytest
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
from homeassistant.helpers.template import TupleWrapper
from homeassistant.util.json import ( from homeassistant.util.json import (
SerializationError, SerializationError,
find_paths_unserializable_data, find_paths_unserializable_data,
@ -83,23 +82,15 @@ def test_overwrite_and_reload(atomic_writes):
def test_save_bad_data(): def test_save_bad_data():
"""Test error from trying to save unserializable data.""" """Test error from trying to save unserializable data."""
class CannotSerializeMe:
"""Cannot serialize this."""
with pytest.raises(SerializationError) as excinfo: with pytest.raises(SerializationError) as excinfo:
save_json("test4", {"hello": set()}) save_json("test4", {"hello": CannotSerializeMe()})
assert ( assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
"Failed to serialize to JSON: test4. Bad data at $.hello=set()(<class 'set'>" excinfo.value
in str(excinfo.value)
)
def test_save_bad_data_tuple_wrapper():
"""Test error from trying to save unserializable data."""
with pytest.raises(SerializationError) as excinfo:
save_json("test4", {"hello": TupleWrapper(("4", "5"))})
assert (
"Failed to serialize to JSON: test4. Bad data at $.hello=('4', '5')(<class 'homeassistant.helpers.template.TupleWrapper'>"
in str(excinfo.value)
) )