diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py index 27acff54f99..77301532d3d 100644 --- a/homeassistant/components/history/__init__.py +++ b/homeassistant/components/history/__init__.py @@ -24,10 +24,10 @@ from homeassistant.components.recorder.statistics import ( ) from homeassistant.components.recorder.util import session_scope from homeassistant.components.websocket_api import messages -from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA +from homeassistant.helpers.json import JSON_DUMP from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util diff --git a/homeassistant/components/logbook/websocket_api.py b/homeassistant/components/logbook/websocket_api.py index 0bb7877b95b..7265bcbae86 100644 --- a/homeassistant/components/logbook/websocket_api.py +++ b/homeassistant/components/logbook/websocket_api.py @@ -14,9 +14,9 @@ from homeassistant.components import websocket_api from homeassistant.components.recorder import get_instance from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api.connection import ActiveConnection -from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.helpers.event import async_track_point_in_utc_time +from homeassistant.helpers.json import JSON_DUMP import homeassistant.util.dt as dt_util from .helpers import ( diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index e558d19b530..e94092d2154 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -1,12 +1,11 @@ """Recorder constants.""" -from functools import partial -import json -from typing import Final from homeassistant.backports.enum import StrEnum from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES -from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import + JSON_DUMP, +) DATA_INSTANCE = "recorder_instance" SQLITE_URL_PREFIX = "sqlite://" @@ -27,7 +26,6 @@ MAX_ROWS_TO_PURGE = 998 DB_WORKER_PREFIX = "DbWorker" -JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":")) ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES} diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 7df4cf57e56..8b15e15042f 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -744,11 +744,12 @@ class Recorder(threading.Thread): return try: - shared_data = EventData.shared_data_from_event(event) + shared_data_bytes = EventData.shared_data_bytes_from_event(event) except (TypeError, ValueError) as ex: _LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex) return + shared_data = shared_data_bytes.decode("utf-8") # Matching attributes found in the pending commit if pending_event_data := self._pending_event_data.get(shared_data): dbevent.event_data_rel = pending_event_data @@ -756,7 +757,7 @@ class Recorder(threading.Thread): elif data_id := self._event_data_ids.get(shared_data): dbevent.data_id = data_id else: - data_hash = EventData.hash_shared_data(shared_data) + data_hash = EventData.hash_shared_data_bytes(shared_data_bytes) # Matching attributes found in the database if data_id := self._find_shared_data_in_db(data_hash, shared_data): self._event_data_ids[shared_data] = dbevent.data_id = data_id @@ -775,7 +776,7 @@ class Recorder(threading.Thread): assert self.event_session is not None try: dbstate = States.from_event(event) - shared_attrs = StateAttributes.shared_attrs_from_event( + shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( event, self._exclude_attributes_by_domain ) except (TypeError, ValueError) as ex: @@ -786,6 +787,7 @@ class Recorder(threading.Thread): ) return + shared_attrs = shared_attrs_bytes.decode("utf-8") dbstate.attributes = None # Matching attributes found in the pending commit if pending_attributes := self._pending_state_attributes.get(shared_attrs): @@ -794,7 +796,7 @@ class Recorder(threading.Thread): elif attributes_id := self._state_attributes_ids.get(shared_attrs): dbstate.attributes_id = attributes_id else: - attr_hash = StateAttributes.hash_shared_attrs(shared_attrs) + attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) # Matching attributes found in the database if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs): dbstate.attributes_id = attributes_id diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 70c816c2af5..e0a22184cc8 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections.abc import Callable from datetime import datetime, timedelta -import json import logging from typing import Any, TypedDict, cast, overload import ciso8601 from fnvhash import fnv1a_32 +import orjson from sqlalchemy import ( JSON, BigInteger, @@ -46,9 +46,10 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers.json import JSON_DUMP, json_bytes import homeassistant.util.dt as dt_util -from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP +from .const import ALL_DOMAIN_EXCLUDE_ATTRS # SQLAlchemy Schema # pylint: disable=invalid-name @@ -132,7 +133,7 @@ class JSONLiteral(JSON): # type: ignore[misc] def process(value: Any) -> str: """Dump json.""" - return json.dumps(value) + return JSON_DUMP(value) return process @@ -199,7 +200,7 @@ class Events(Base): # type: ignore[misc,valid-type] try: return Event( self.event_type, - json.loads(self.event_data) if self.event_data else {}, + orjson.loads(self.event_data) if self.event_data else {}, EventOrigin(self.origin) if self.origin else EVENT_ORIGIN_ORDER[self.origin_idx], @@ -207,7 +208,7 @@ class Events(Base): # type: ignore[misc,valid-type] context=context, ) except ValueError: - # When json.loads fails + # When orjson.loads fails _LOGGER.exception("Error converting to event: %s", self) return None @@ -235,25 +236,26 @@ class EventData(Base): # type: ignore[misc,valid-type] @staticmethod def from_event(event: Event) -> EventData: """Create object from an event.""" - shared_data = JSON_DUMP(event.data) + shared_data = json_bytes(event.data) return EventData( - shared_data=shared_data, hash=EventData.hash_shared_data(shared_data) + shared_data=shared_data.decode("utf-8"), + hash=EventData.hash_shared_data_bytes(shared_data), ) @staticmethod - def shared_data_from_event(event: Event) -> str: - """Create shared_attrs from an event.""" - return JSON_DUMP(event.data) + def shared_data_bytes_from_event(event: Event) -> bytes: + """Create shared_data from an event.""" + return json_bytes(event.data) @staticmethod - def hash_shared_data(shared_data: str) -> int: + def hash_shared_data_bytes(shared_data_bytes: bytes) -> int: """Return the hash of json encoded shared data.""" - return cast(int, fnv1a_32(shared_data.encode("utf-8"))) + return cast(int, fnv1a_32(shared_data_bytes)) def to_native(self) -> dict[str, Any]: """Convert to an HA state object.""" try: - return cast(dict[str, Any], json.loads(self.shared_data)) + return cast(dict[str, Any], orjson.loads(self.shared_data)) except ValueError: _LOGGER.exception("Error converting row to event data: %s", self) return {} @@ -340,9 +342,9 @@ class States(Base): # type: ignore[misc,valid-type] parent_id=self.context_parent_id, ) try: - attrs = json.loads(self.attributes) if self.attributes else {} + attrs = orjson.loads(self.attributes) if self.attributes else {} except ValueError: - # When json.loads fails + # When orjson.loads fails _LOGGER.exception("Error converting row to state: %s", self) return None if self.last_changed is None or self.last_changed == self.last_updated: @@ -388,40 +390,39 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] """Create object from a state_changed event.""" state: State | None = event.data.get("new_state") # None state means the state was removed from the state machine - dbstate = StateAttributes( - shared_attrs="{}" if state is None else JSON_DUMP(state.attributes) - ) - dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs) + attr_bytes = b"{}" if state is None else json_bytes(state.attributes) + dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8")) + dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes) return dbstate @staticmethod - def shared_attrs_from_event( + def shared_attrs_bytes_from_event( event: Event, exclude_attrs_by_domain: dict[str, set[str]] - ) -> str: + ) -> bytes: """Create shared_attrs from a state_changed event.""" state: State | None = event.data.get("new_state") # None state means the state was removed from the state machine if state is None: - return "{}" + return b"{}" domain = split_entity_id(state.entity_id)[0] exclude_attrs = ( exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS ) - return JSON_DUMP( + return json_bytes( {k: v for k, v in state.attributes.items() if k not in exclude_attrs} ) @staticmethod - def hash_shared_attrs(shared_attrs: str) -> int: - """Return the hash of json encoded shared attributes.""" - return cast(int, fnv1a_32(shared_attrs.encode("utf-8"))) + def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: + """Return the hash of orjson encoded shared attributes.""" + return cast(int, fnv1a_32(shared_attrs_bytes)) def to_native(self) -> dict[str, Any]: """Convert to an HA state object.""" try: - return cast(dict[str, Any], json.loads(self.shared_attrs)) + return cast(dict[str, Any], orjson.loads(self.shared_attrs)) except ValueError: - # When json.loads fails + # When orjson.loads fails _LOGGER.exception("Error converting row to state attributes: %s", self) return {} @@ -835,7 +836,7 @@ def decode_attributes_from_row( if not source or source == EMPTY_JSON_OBJECT: return {} try: - attr_cache[source] = attributes = json.loads(source) + attr_cache[source] = attributes = orjson.loads(source) except ValueError: _LOGGER.exception("Error converting row to state attributes: %s", source) attr_cache[source] = attributes = {} diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 61bcb8badf0..bea08722eb0 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -29,7 +29,7 @@ from homeassistant.helpers.event import ( TrackTemplateResult, async_track_template_result, ) -from homeassistant.helpers.json import ExtendedJSONEncoder +from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations @@ -241,13 +241,13 @@ def handle_get_states( # to succeed for the UI to show. response = messages.result_message(msg["id"], states) try: - connection.send_message(const.JSON_DUMP(response)) + connection.send_message(JSON_DUMP(response)) return except (ValueError, TypeError): connection.logger.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(response, dump=const.JSON_DUMP) + find_paths_unserializable_data(response, dump=JSON_DUMP) ), ) del response @@ -256,13 +256,13 @@ def handle_get_states( serialized = [] for state in states: try: - serialized.append(const.JSON_DUMP(state)) + serialized.append(JSON_DUMP(state)) except (ValueError, TypeError): # Error is already logged above pass # We now have partially serialized states. Craft some JSON. - response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) + response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized)) connection.send_message(response2) @@ -315,13 +315,13 @@ def handle_subscribe_entities( # to succeed for the UI to show. response = messages.event_message(msg["id"], data) try: - connection.send_message(const.JSON_DUMP(response)) + connection.send_message(JSON_DUMP(response)) return except (ValueError, TypeError): connection.logger.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(response, dump=const.JSON_DUMP) + find_paths_unserializable_data(response, dump=JSON_DUMP) ), ) del response @@ -330,14 +330,14 @@ def handle_subscribe_entities( cannot_serialize: list[str] = [] for entity_id, state_dict in add_entities.items(): try: - const.JSON_DUMP(state_dict) + JSON_DUMP(state_dict) except (ValueError, TypeError): cannot_serialize.append(entity_id) for entity_id in cannot_serialize: del add_entities[entity_id] - connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data))) + connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data))) @decorators.websocket_command({vol.Required("type"): "get_services"}) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 0280863f83e..26c4c6f8321 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -11,6 +11,7 @@ import voluptuous as vol from homeassistant.auth.models import RefreshToken, User from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized +from homeassistant.helpers.json import JSON_DUMP from . import const, messages @@ -56,7 +57,7 @@ class ActiveConnection: async def send_big_result(self, msg_id: int, result: Any) -> None: """Send a result message that would be expensive to JSON serialize.""" content = await self.hass.async_add_executor_job( - const.JSON_DUMP, messages.result_message(msg_id, result) + JSON_DUMP, messages.result_message(msg_id, result) ) self.send_message(content) diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 107cf6d0270..60a00126092 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -4,12 +4,9 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable from concurrent import futures -from functools import partial -import json from typing import TYPE_CHECKING, Any, Final from homeassistant.core import HomeAssistant -from homeassistant.helpers.json import JSONEncoder if TYPE_CHECKING: from .connection import ActiveConnection # noqa: F401 @@ -53,10 +50,6 @@ SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected" # Data used to store the current connection list DATA_CONNECTIONS: Final = f"{DOMAIN}.connections" -JSON_DUMP: Final = partial( - json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":") -) - COMPRESSED_STATE_STATE = "s" COMPRESSED_STATE_ATTRIBUTES = "a" COMPRESSED_STATE_CONTEXT = "c" diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index f546ba5eec6..c3e5f6bb5f5 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -9,6 +9,7 @@ import voluptuous as vol from homeassistant.core import Event, State from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.json import JSON_DUMP from homeassistant.util.json import ( find_paths_unserializable_data, format_unserializable_data, @@ -193,15 +194,15 @@ def compressed_state_dict_add(state: State) -> dict[str, Any]: def message_to_json(message: dict[str, Any]) -> str: """Serialize a websocket message to json.""" try: - return const.JSON_DUMP(message) + return JSON_DUMP(message) except (ValueError, TypeError): _LOGGER.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(message, dump=const.JSON_DUMP) + find_paths_unserializable_data(message, dump=JSON_DUMP) ), ) - return const.JSON_DUMP( + return JSON_DUMP( error_message( message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response" ) diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index eaabb002b0a..2e56698db41 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -14,6 +14,7 @@ from aiohttp import web from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout import async_timeout +import orjson from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ @@ -97,6 +98,7 @@ def _async_create_clientsession( """Create a new ClientSession with kwargs, i.e. for cookies.""" clientsession = aiohttp.ClientSession( connector=_async_get_connector(hass, verify_ssl), + json_serialize=lambda x: orjson.dumps(x).decode("utf-8"), **kwargs, ) # Prevent packages accidentally overriding our default headers diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index c581e5a9361..912667a13b5 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -1,7 +1,10 @@ """Helpers to help with encoding Home Assistant objects in JSON.""" import datetime import json -from typing import Any +from pathlib import Path +from typing import Any, Final + +import orjson class JSONEncoder(json.JSONEncoder): @@ -22,6 +25,20 @@ class JSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) +def json_encoder_default(obj: Any) -> Any: + """Convert Home Assistant objects. + + Hand other objects to the original method. + """ + if isinstance(obj, set): + return list(obj) + if hasattr(obj, "as_dict"): + return obj.as_dict() + if isinstance(obj, Path): + return obj.as_posix() + raise TypeError + + class ExtendedJSONEncoder(JSONEncoder): """JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" @@ -40,3 +57,31 @@ class ExtendedJSONEncoder(JSONEncoder): return super().default(o) except TypeError: return {"__type": str(type(o)), "repr": repr(o)} + + +def json_bytes(data: Any) -> bytes: + """Dump json bytes.""" + return orjson.dumps( + data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default + ) + + +def json_dumps(data: Any) -> str: + """Dump json string. + + orjson supports serializing dataclasses natively which + eliminates the need to implement as_dict in many places + when the data is already in a dataclass. This works + well as long as all the data in the dataclass can also + be serialized. + + If it turns out to be a problem we can disable this + with option |= orjson.OPT_PASSTHROUGH_DATACLASS and it + will fallback to as_dict + """ + return orjson.dumps( + data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default + ).decode("utf-8") + + +JSON_DUMP: Final = json_dumps diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index a3d8a00bcfb..c158d26a9aa 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -20,6 +20,7 @@ httpx==0.23.0 ifaddr==0.1.7 jinja2==3.1.2 lru-dict==1.1.7 +orjson==3.6.8 paho-mqtt==1.6.1 pillow==9.1.1 pip>=21.0,<22.2 diff --git a/homeassistant/scripts/benchmark/__init__.py b/homeassistant/scripts/benchmark/__init__.py index a681b3e210d..efbfec5e961 100644 --- a/homeassistant/scripts/benchmark/__init__.py +++ b/homeassistant/scripts/benchmark/__init__.py @@ -12,14 +12,13 @@ from timeit import default_timer as timer from typing import TypeVar from homeassistant import core -from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.helpers.entityfilter import convert_include_exclude_filter from homeassistant.helpers.event import ( async_track_state_change, async_track_state_change_event, ) -from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.json import JSON_DUMP, JSONEncoder # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: no-warn-return-any diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index fdee7a7a90f..82ecfd34d6d 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -7,6 +7,8 @@ import json import logging from typing import Any +import orjson + from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError @@ -30,7 +32,7 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict: """ try: with open(filename, encoding="utf-8") as fdesc: - return json.loads(fdesc.read()) # type: ignore[no-any-return] + return orjson.loads(fdesc.read()) # type: ignore[no-any-return] except FileNotFoundError: # This is not a fatal error _LOGGER.debug("JSON file not found: %s", filename) @@ -56,7 +58,10 @@ def save_json( Returns True on success. """ try: - json_data = json.dumps(data, indent=4, cls=encoder) + if encoder: + json_data = json.dumps(data, indent=2, cls=encoder) + else: + json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8") except TypeError as error: msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}" _LOGGER.error(msg) diff --git a/pyproject.toml b/pyproject.toml index cc745f58ad6..7e62bafd6af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "PyJWT==2.4.0", # PyJWT has loose dependency. We want the latest one. "cryptography==36.0.2", + "orjson==3.6.8", "pip>=21.0,<22.2", "python-slugify==4.0.1", "pyyaml==6.0", @@ -119,6 +120,7 @@ extension-pkg-allow-list = [ "av.audio.stream", "av.stream", "ciso8601", + "orjson", "cv2", ] diff --git a/requirements.txt b/requirements.txt index fe2bf87ad25..9805ae7cd47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ ifaddr==0.1.7 jinja2==3.1.2 PyJWT==2.4.0 cryptography==36.0.2 +orjson==3.6.8 pip>=21.0,<22.2 python-slugify==4.0.1 pyyaml==6.0 diff --git a/tests/components/energy/test_validate.py b/tests/components/energy/test_validate.py index 37ebe4147c5..e802688daaf 100644 --- a/tests/components/energy/test_validate.py +++ b/tests/components/energy/test_validate.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from homeassistant.components.energy import async_get_manager, validate +from homeassistant.helpers.json import JSON_DUMP from homeassistant.setup import async_setup_component @@ -408,7 +409,11 @@ async def test_validation_grid( }, ) - assert (await validate.async_validate(hass)).as_dict() == { + result = await validate.async_validate(hass) + # verify its also json serializable + JSON_DUMP(result) + + assert result.as_dict() == { "energy_sources": [ [ { diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 4d3302f7c13..0f4695596fc 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -619,12 +619,15 @@ async def test_states_filters_visible(hass, hass_admin_user, websocket_client): async def test_get_states_not_allows_nan(hass, websocket_client): - """Test get_states command not allows NaN floats.""" + """Test get_states command converts NaN to None.""" hass.states.async_set("greeting.hello", "world") hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")}) hass.states.async_set("greeting.bye", "universe") await websocket_client.send_json({"id": 5, "type": "get_states"}) + bad = dict(hass.states.get("greeting.bad").as_dict()) + bad["attributes"] = dict(bad["attributes"]) + bad["attributes"]["hello"] = None msg = await websocket_client.receive_json() assert msg["id"] == 5 @@ -632,6 +635,7 @@ async def test_get_states_not_allows_nan(hass, websocket_client): assert msg["success"] assert msg["result"] == [ hass.states.get("greeting.hello").as_dict(), + bad, hass.states.get("greeting.bye").as_dict(), ]