diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 80ec35f5f7e..b22eff150ba 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -11,6 +11,10 @@ from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback from homeassistant.helpers.event import async_call_later +from homeassistant.util.json import ( + find_paths_unserializable_data, + format_unserializable_data, +) from .auth import AuthPhase, auth_required_message from .const import ( @@ -74,15 +78,18 @@ class WebSocketHandler: try: dumped = JSON_DUMP(message) - except (ValueError, TypeError) as err: - self._logger.error( - "Unable to serialize to JSON: %s\n%s", err, message - ) + except (ValueError, TypeError): await self.wsock.send_json( error_message( message["id"], ERR_UNKNOWN_ERROR, "Invalid JSON in response" ) ) + self._logger.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(message, dump=JSON_DUMP) + ), + ) continue await self.wsock.send_str(dumped) diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index 94dc816e03c..c5da910fae1 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -4,8 +4,9 @@ import json import logging import os import tempfile -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union +from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError _LOGGER = logging.getLogger(__name__) @@ -56,7 +57,7 @@ def save_json( json_data = json.dumps(data, sort_keys=True, indent=4, cls=encoder) except TypeError: # pylint: disable=no-member - msg = f"Failed to serialize to JSON: {filename}. Bad data found at {', '.join(find_paths_unserializable_data(data))}" + msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}" _LOGGER.error(msg) raise SerializationError(msg) @@ -85,30 +86,48 @@ def save_json( _LOGGER.error("JSON replacement cleanup failed: %s", err) -def find_paths_unserializable_data(bad_data: Any) -> List[str]: +def format_unserializable_data(data: Dict[str, Any]) -> str: + """Format output of find_paths in a friendly way. + + Format is comma separated: =() + """ + return ", ".join(f"{path}={value}({type(value)}" for path, value in data.items()) + + +def find_paths_unserializable_data( + bad_data: Any, *, dump: Callable[[Any], str] = json.dumps +) -> Dict[str, Any]: """Find the paths to unserializable data. This method is slow! Only use for error handling. """ to_process = deque([(bad_data, "$")]) - invalid = [] + invalid = {} while to_process: obj, obj_path = to_process.popleft() try: - json.dumps(obj) + dump(obj) continue - except TypeError: + except (ValueError, TypeError): pass + # We convert states and events to dict so we can find bad data inside it + if isinstance(obj, State): + obj_path += f"(state: {obj.entity_id})" + obj = obj.as_dict() + elif isinstance(obj, Event): + obj_path += f"(event: {obj.event_type})" + obj = obj.as_dict() + if isinstance(obj, dict): for key, value in obj.items(): try: # Is key valid? - json.dumps({key: None}) + dump({key: None}) except TypeError: - invalid.append(f"{obj_path}") + invalid[f"{obj_path}"] = key else: # Process value to_process.append((value, f"{obj_path}.{key}")) @@ -116,6 +135,6 @@ def find_paths_unserializable_data(bad_data: Any) -> List[str]: for idx, value in enumerate(obj): to_process.append((value, f"{obj_path}[{idx}]")) else: - invalid.append(obj_path) + invalid[obj_path] = obj return invalid diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 33a019b2e70..9082337ccc8 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -64,3 +64,19 @@ async def test_pending_msg_peak(hass, mock_low_peak, hass_ws_client, caplog): assert msg.type == WSMsgType.close assert "Client unable to keep up with pending messages" in caplog.text + + +async def test_non_json_message(hass, websocket_client, caplog): + """Test trying to serialze non JSON objects.""" + bad_data = object() + hass.states.async_set("test_domain.entity", "testing", {"bad": bad_data}) + await websocket_client.send_json({"id": 5, "type": "get_states"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 5 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert ( + f"Unable to serialize to JSON. Bad data found at $.result[0](state: test_domain.entity).attributes.bad={bad_data}(" + in caplog.text + ) diff --git a/tests/util/test_json.py b/tests/util/test_json.py index ed699c8eded..258f266ff78 100644 --- a/tests/util/test_json.py +++ b/tests/util/test_json.py @@ -1,5 +1,8 @@ """Test Home Assistant json utility functions.""" -from json import JSONEncoder +from datetime import datetime +from functools import partial +from json import JSONEncoder, dumps +import math import os import sys from tempfile import mkdtemp @@ -8,6 +11,7 @@ from unittest.mock import Mock import pytest +from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError from homeassistant.util.json import ( SerializationError, @@ -77,8 +81,9 @@ def test_save_bad_data(): with pytest.raises(SerializationError) as excinfo: save_json("test4", {"hello": set()}) - assert "Failed to serialize to JSON: test4. Bad data found at $.hello" in str( - excinfo.value + assert ( + "Failed to serialize to JSON: test4. Bad data at $.hello=set()(" + in str(excinfo.value) ) @@ -109,16 +114,46 @@ def test_custom_encoder(): def test_find_unserializable_data(): """Find unserializeable data.""" - assert find_paths_unserializable_data(1) == [] - assert find_paths_unserializable_data([1, 2]) == [] - assert find_paths_unserializable_data({"something": "yo"}) == [] + assert find_paths_unserializable_data(1) == {} + assert find_paths_unserializable_data([1, 2]) == {} + assert find_paths_unserializable_data({"something": "yo"}) == {} - assert find_paths_unserializable_data({"something": set()}) == ["$.something"] - assert find_paths_unserializable_data({"something": [1, set()]}) == [ - "$.something[1]" - ] - assert find_paths_unserializable_data([1, {"bla": set(), "blub": set()}]) == [ - "$[1].bla", - "$[1].blub", - ] - assert find_paths_unserializable_data({("A",): 1}) == ["$"] + assert find_paths_unserializable_data({"something": set()}) == { + "$.something": set() + } + assert find_paths_unserializable_data({"something": [1, set()]}) == { + "$.something[1]": set() + } + assert find_paths_unserializable_data([1, {"bla": set(), "blub": set()}]) == { + "$[1].bla": set(), + "$[1].blub": set(), + } + assert find_paths_unserializable_data({("A",): 1}) == {"$": ("A",)} + assert math.isnan( + find_paths_unserializable_data( + float("nan"), dump=partial(dumps, allow_nan=False) + )["$"] + ) + + # Test custom encoder + State support. + + class MockJSONEncoder(JSONEncoder): + """Mock JSON encoder.""" + + def default(self, o): + """Mock JSON encode method.""" + if isinstance(o, datetime): + return o.isoformat() + return super().default(o) + + bad_data = object() + + assert find_paths_unserializable_data( + [State("mock_domain.mock_entity", "on", {"bad": bad_data})], + dump=partial(dumps, cls=MockJSONEncoder), + ) == {"$[0](state: mock_domain.mock_entity).attributes.bad": bad_data} + + assert find_paths_unserializable_data( + [Event("bad_event", {"bad_attribute": bad_data})], + dump=partial(dumps, cls=MockJSONEncoder), + ) == {"$[0](event: bad_event).data.bad_attribute": bad_data}