diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 20a98af4b2f..a97eed8eff6 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -836,7 +836,9 @@ class Recorder(threading.Thread): return try: - shared_data_bytes = EventData.shared_data_bytes_from_event(event) + shared_data_bytes = EventData.shared_data_bytes_from_event( + event, self.dialect_name + ) except JSON_ENCODE_EXCEPTIONS as ex: _LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex) return @@ -869,7 +871,7 @@ class Recorder(threading.Thread): try: dbstate = States.from_event(event) shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( - event, self._exclude_attributes_by_domain + event, self._exclude_attributes_by_domain, self.dialect_name ) except JSON_ENCODE_EXCEPTIONS as ex: _LOGGER.warning( diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index 1b5ac87c24a..47b9658b053 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -43,11 +43,12 @@ from homeassistant.helpers.json import ( JSON_DECODE_EXCEPTIONS, JSON_DUMP, json_bytes, + json_bytes_strip_null, json_loads, ) import homeassistant.util.dt as dt_util -from .const import ALL_DOMAIN_EXCLUDE_ATTRS +from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect from .models import StatisticData, StatisticMetaData, process_timestamp # SQLAlchemy Schema @@ -251,8 +252,12 @@ class EventData(Base): # type: ignore[misc,valid-type] ) @staticmethod - def shared_data_bytes_from_event(event: Event) -> bytes: + def shared_data_bytes_from_event( + event: Event, dialect: SupportedDialect | None + ) -> bytes: """Create shared_data from an event.""" + if dialect == SupportedDialect.POSTGRESQL: + return json_bytes_strip_null(event.data) return json_bytes(event.data) @staticmethod @@ -416,7 +421,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] @staticmethod def shared_attrs_bytes_from_event( - event: Event, exclude_attrs_by_domain: dict[str, set[str]] + event: Event, + exclude_attrs_by_domain: dict[str, set[str]], + dialect: SupportedDialect | None, ) -> bytes: """Create shared_attrs from a state_changed event.""" state: State | None = event.data.get("new_state") @@ -427,6 +434,10 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] exclude_attrs = ( exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS ) + if dialect == SupportedDialect.POSTGRESQL: + return json_bytes_strip_null( + {k: v for k, v in state.attributes.items() if k not in exclude_attrs} + ) return json_bytes( {k: v for k, v in state.attributes.items() if k not in exclude_attrs} ) diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index 74a2f542910..2a499dc0d97 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -71,6 +71,40 @@ def json_bytes(data: Any) -> bytes: ) +def json_bytes_strip_null(data: Any) -> bytes: + """Dump json bytes after terminating strings at the first NUL.""" + + def process_dict(_dict: dict[Any, Any]) -> dict[Any, Any]: + """Strip NUL from items in a dict.""" + return {key: strip_null(o) for key, o in _dict.items()} + + def process_list(_list: list[Any]) -> list[Any]: + """Strip NUL from items in a list.""" + return [strip_null(o) for o in _list] + + def strip_null(obj: Any) -> Any: + """Strip NUL from an object.""" + if isinstance(obj, str): + return obj.split("\0", 1)[0] + if isinstance(obj, dict): + return process_dict(obj) + if isinstance(obj, list): + return process_list(obj) + return obj + + # We expect null-characters to be very rare, hence try encoding first and look + # for an escaped null-character in the output. + result = json_bytes(data) + if b"\\u0000" in result: + # We work on the processed result so we don't need to worry about + # Home Assistant extensions which allows encoding sets, tuples, etc. + data_processed = orjson.loads(result) + data_processed = strip_null(data_processed) + result = json_bytes(data_processed) + + return result + + def json_dumps(data: Any) -> str: """Dump json string. diff --git a/tests/components/recorder/db_schema_30.py b/tests/components/recorder/db_schema_30.py index 8854cd33a61..01c31807ff7 100644 --- a/tests/components/recorder/db_schema_30.py +++ b/tests/components/recorder/db_schema_30.py @@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import aliased, declarative_base, relationship from sqlalchemy.orm.session import Session +from homeassistant.components.recorder.const import SupportedDialect from homeassistant.const import ( ATTR_ATTRIBUTION, ATTR_RESTORED, @@ -287,7 +288,9 @@ class EventData(Base): # type: ignore[misc,valid-type] ) @staticmethod - def shared_data_bytes_from_event(event: Event) -> bytes: + def shared_data_bytes_from_event( + event: Event, dialect: SupportedDialect | None + ) -> bytes: """Create shared_data from an event.""" return json_bytes(event.data) @@ -438,7 +441,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] @staticmethod def shared_attrs_bytes_from_event( - event: Event, exclude_attrs_by_domain: dict[str, set[str]] + event: Event, + exclude_attrs_by_domain: dict[str, set[str]], + dialect: SupportedDialect | None, ) -> bytes: """Create shared_attrs from a state_changed event.""" state: State | None = event.data.get("new_state") diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 8f32cfb6a62..c06865fb5a3 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -31,6 +31,7 @@ from homeassistant.components.recorder.const import ( EVENT_RECORDER_5MIN_STATISTICS_GENERATED, EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, KEEPALIVE_TIME, + SupportedDialect, ) from homeassistant.components.recorder.db_schema import ( SCHEMA_VERSION, @@ -223,6 +224,42 @@ async def test_saving_state(recorder_mock, hass: HomeAssistant): assert state == _state_with_context(hass, entity_id) +@pytest.mark.parametrize( + "dialect_name, expected_attributes", + ( + (SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}), + (SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}), + (SupportedDialect.SQLITE, {"test_attr": 5, "test_attr_10": "silly\0stuff"}), + ), +) +async def test_saving_state_with_nul( + recorder_mock, hass: HomeAssistant, dialect_name, expected_attributes +): + """Test saving and restoring a state with nul in attributes.""" + entity_id = "test.recorder" + state = "restoring_from_db" + attributes = {"test_attr": 5, "test_attr_10": "silly\0stuff"} + + with patch( + "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name + ): + hass.states.async_set(entity_id, state, attributes) + await async_wait_recording_done(hass) + + with session_scope(hass=hass) as session: + db_states = [] + for db_state, db_state_attributes in session.query(States, StateAttributes): + db_states.append(db_state) + state = db_state.to_native() + state.attributes = db_state_attributes.to_native() + assert len(db_states) == 1 + assert db_states[0].event_id is None + + expected = _state_with_context(hass, entity_id) + expected.attributes = expected_attributes + assert state == expected + + async def test_saving_many_states( async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant ): diff --git a/tests/helpers/test_json.py b/tests/helpers/test_json.py index 1e85338f152..92583fcfba8 100644 --- a/tests/helpers/test_json.py +++ b/tests/helpers/test_json.py @@ -10,6 +10,7 @@ from homeassistant import core from homeassistant.helpers.json import ( ExtendedJSONEncoder, JSONEncoder, + json_bytes_strip_null, json_dumps, json_dumps_sorted, ) @@ -118,3 +119,19 @@ def test_json_dumps_rgb_color_subclass(): rgb = RGBColor(4, 2, 1) assert json_dumps(rgb) == "[4,2,1]" + + +def test_json_bytes_strip_null(): + """Test stripping nul from strings.""" + + assert json_bytes_strip_null("\0") == b'""' + assert json_bytes_strip_null("silly\0stuff") == b'"silly"' + assert json_bytes_strip_null(["one", "two\0", "three"]) == b'["one","two","three"]' + assert ( + json_bytes_strip_null({"k1": "one", "k2": "two\0", "k3": "three"}) + == b'{"k1":"one","k2":"two","k3":"three"}' + ) + assert ( + json_bytes_strip_null([[{"k1": {"k2": ["silly\0stuff"]}}]]) + == b'[[{"k1":{"k2":["silly"]}}]]' + )