diff --git a/homeassistant/components/trace/utils.py b/homeassistant/components/trace/utils.py index 7e804724c55..50d1590e4fd 100644 --- a/homeassistant/components/trace/utils.py +++ b/homeassistant/components/trace/utils.py @@ -1,9 +1,5 @@ """Helpers for script and automation tracing and debugging.""" from collections import OrderedDict -from datetime import timedelta -from typing import Any - -from homeassistant.helpers.json import JSONEncoder as HAJSONEncoder class LimitedSizeDict(OrderedDict): @@ -25,19 +21,3 @@ class LimitedSizeDict(OrderedDict): if self.size_limit is not None: while len(self) > self.size_limit: self.popitem(last=False) - - -class TraceJSONEncoder(HAJSONEncoder): - """JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" - - def default(self, o: Any) -> Any: - """Convert certain objects. - - Fall back to repr(o). - """ - if isinstance(o, timedelta): - return {"__type": str(type(o)), "total_seconds": o.total_seconds()} - try: - return super().default(o) - except TypeError: - return {"__type": str(type(o)), "repr": repr(o)} diff --git a/homeassistant/components/trace/websocket_api.py b/homeassistant/components/trace/websocket_api.py index 17f3dc7860d..8f59660e74d 100644 --- a/homeassistant/components/trace/websocket_api.py +++ b/homeassistant/components/trace/websocket_api.py @@ -11,6 +11,7 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, ) +from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.script import ( SCRIPT_BREAKPOINT_HIT, SCRIPT_DEBUG_CONTINUE_ALL, @@ -24,7 +25,6 @@ from homeassistant.helpers.script import ( ) from .const import DATA_TRACE -from .utils import TraceJSONEncoder # mypy: allow-untyped-calls, allow-untyped-defs @@ -71,7 +71,9 @@ def websocket_trace_get(hass, connection, msg): message = websocket_api.messages.result_message(msg["id"], trace) - connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False)) + connection.send_message( + json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False) + ) def get_debug_traces(hass, key): diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 301f106edcc..4045477f75e 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -1,5 +1,6 @@ """Commands part of Websocket API.""" import asyncio +import json import voluptuous as vol @@ -17,6 +18,7 @@ from homeassistant.exceptions import ( from homeassistant.helpers import config_validation as cv, entity, template from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.event import TrackTemplate, async_track_template_result +from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations @@ -417,10 +419,11 @@ async def handle_subscribe_trigger(hass, connection, msg): @callback def forward_triggers(variables, context=None): """Forward events to websocket.""" + message = messages.event_message( + msg["id"], {"variables": variables, "context": context} + ) connection.send_message( - messages.event_message( - msg["id"], {"variables": variables, "context": context} - ) + json.dumps(message, cls=ExtendedJSONEncoder, allow_nan=False) ) connection.subscriptions[msg["id"]] = ( diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index 3168310dc59..738f744194f 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -1,5 +1,5 @@ """Helpers to help with encoding Home Assistant objects in JSON.""" -from datetime import datetime +from datetime import datetime, timedelta import json from typing import Any @@ -20,3 +20,19 @@ class JSONEncoder(json.JSONEncoder): return o.as_dict() return json.JSONEncoder.default(self, o) + + +class ExtendedJSONEncoder(JSONEncoder): + """JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" + + def default(self, o: Any) -> Any: + """Convert certain objects. + + Fall back to repr(o). + """ + if isinstance(o, timedelta): + return {"__type": str(type(o)), "total_seconds": o.total_seconds()} + try: + return super().default(o) + except TypeError: + return {"__type": str(type(o)), "repr": repr(o)} diff --git a/tests/components/trace/test_utils.py b/tests/components/trace/test_utils.py deleted file mode 100644 index ce0f09bfdd8..00000000000 --- a/tests/components/trace/test_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Test trace helpers.""" -from datetime import timedelta - -from homeassistant import core -from homeassistant.components import trace -from homeassistant.util import dt as dt_util - - -def test_json_encoder(hass): - """Test the Trace JSON Encoder.""" - ha_json_enc = trace.utils.TraceJSONEncoder() - state = core.State("test.test", "hello") - - # Test serializing a datetime - now = dt_util.utcnow() - assert ha_json_enc.default(now) == now.isoformat() - - # Test serializing a timedelta - data = timedelta( - days=50, - seconds=27, - microseconds=10, - milliseconds=29000, - minutes=5, - hours=8, - weeks=2, - ) - assert ha_json_enc.default(data) == { - "__type": str(type(data)), - "total_seconds": data.total_seconds(), - } - - # Test serializing a set() - data = {"milk", "beer"} - assert sorted(ha_json_enc.default(data)) == sorted(data) - - # Test serializong object which implements as_dict - assert ha_json_enc.default(state) == state.as_dict() - - # Default method falls back to repr(o) - o = object() - assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)} diff --git a/tests/helpers/test_json.py b/tests/helpers/test_json.py index 1a68f2b8da5..076af218676 100644 --- a/tests/helpers/test_json.py +++ b/tests/helpers/test_json.py @@ -1,8 +1,10 @@ """Test Home Assistant remote methods and classes.""" +from datetime import timedelta + import pytest from homeassistant import core -from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.json import ExtendedJSONEncoder, JSONEncoder from homeassistant.util import dt as dt_util @@ -25,3 +27,39 @@ def test_json_encoder(hass): # Default method raises TypeError if non HA object with pytest.raises(TypeError): ha_json_enc.default(1) + + +def test_trace_json_encoder(hass): + """Test the Trace JSON Encoder.""" + ha_json_enc = ExtendedJSONEncoder() + state = core.State("test.test", "hello") + + # Test serializing a datetime + now = dt_util.utcnow() + assert ha_json_enc.default(now) == now.isoformat() + + # Test serializing a timedelta + data = timedelta( + days=50, + seconds=27, + microseconds=10, + milliseconds=29000, + minutes=5, + hours=8, + weeks=2, + ) + assert ha_json_enc.default(data) == { + "__type": str(type(data)), + "total_seconds": data.total_seconds(), + } + + # Test serializing a set() + data = {"milk", "beer"} + assert sorted(ha_json_enc.default(data)) == sorted(data) + + # Test serializong object which implements as_dict + assert ha_json_enc.default(state) == state.as_dict() + + # Default method falls back to repr(o) + o = object() + assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}