Terminate strings at NUL when recording states and events (#86687)

This commit is contained in:
Erik Montnemery 2023-01-26 11:11:03 +01:00 committed by GitHub
parent b9ffc67a44
commit fea30c1ce9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 7 deletions

View file

@ -836,7 +836,9 @@ class Recorder(threading.Thread):
return return
try: try:
shared_data_bytes = EventData.shared_data_bytes_from_event(event) shared_data_bytes = EventData.shared_data_bytes_from_event(
event, self.dialect_name
)
except JSON_ENCODE_EXCEPTIONS as ex: except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex) _LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return return
@ -869,7 +871,7 @@ class Recorder(threading.Thread):
try: try:
dbstate = States.from_event(event) dbstate = States.from_event(event)
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain event, self._exclude_attributes_by_domain, self.dialect_name
) )
except JSON_ENCODE_EXCEPTIONS as ex: except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning( _LOGGER.warning(

View file

@ -43,11 +43,12 @@ from homeassistant.helpers.json import (
JSON_DECODE_EXCEPTIONS, JSON_DECODE_EXCEPTIONS,
JSON_DUMP, JSON_DUMP,
json_bytes, json_bytes,
json_bytes_strip_null,
json_loads, json_loads,
) )
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import ALL_DOMAIN_EXCLUDE_ATTRS from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
from .models import StatisticData, StatisticMetaData, process_timestamp from .models import StatisticData, StatisticMetaData, process_timestamp
# SQLAlchemy Schema # SQLAlchemy Schema
@ -251,8 +252,12 @@ class EventData(Base): # type: ignore[misc,valid-type]
) )
@staticmethod @staticmethod
def shared_data_bytes_from_event(event: Event) -> bytes: def shared_data_bytes_from_event(
event: Event, dialect: SupportedDialect | None
) -> bytes:
"""Create shared_data from an event.""" """Create shared_data from an event."""
if dialect == SupportedDialect.POSTGRESQL:
return json_bytes_strip_null(event.data)
return json_bytes(event.data) return json_bytes(event.data)
@staticmethod @staticmethod
@ -416,7 +421,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
@staticmethod @staticmethod
def shared_attrs_bytes_from_event( def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]] event: Event,
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes: ) -> bytes:
"""Create shared_attrs from a state_changed event.""" """Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")
@ -427,6 +434,10 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
exclude_attrs = ( exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
) )
if dialect == SupportedDialect.POSTGRESQL:
return json_bytes_strip_null(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
return json_bytes( return json_bytes(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs} {k: v for k, v in state.attributes.items() if k not in exclude_attrs}
) )

View file

@ -71,6 +71,40 @@ def json_bytes(data: Any) -> bytes:
) )
def json_bytes_strip_null(data: Any) -> bytes:
    """Dump json bytes after terminating strings at the first NUL.

    Every string in the encoded document, dict keys as well as values, is
    truncated at its first NUL character, so the returned bytes contain no
    escaped NUL. Data without any NUL is encoded exactly once and returned
    unchanged.

    Note: if two keys of the same dict only differ after a NUL they collapse
    into a single key when truncated; the value of the later one is kept.
    """

    def strip_null(obj: Any) -> Any:
        """Recursively truncate all strings in a decoded JSON value."""
        if isinstance(obj, str):
            return obj.split("\0", 1)[0]
        if isinstance(obj, dict):
            # Keys of a decoded JSON object are strings as well and would
            # otherwise carry an escaped NUL into the output.
            return {strip_null(key): strip_null(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [strip_null(item) for item in obj]
        return obj

    # We expect null-characters to be very rare, hence try encoding first and look
    # for an escaped null-character in the output. A literal backslash followed
    # by "u0000" in the data also matches; that only costs an unneeded second
    # pass, the content is left untouched because it holds no real NUL.
    result = json_bytes(data)
    if b"\\u0000" not in result:
        return result

    # We work on the processed result so we don't need to worry about
    # Home Assistant extensions which allows encoding sets, tuples, etc.
    return json_bytes(strip_null(orjson.loads(result)))
def json_dumps(data: Any) -> str: def json_dumps(data: Any) -> str:
"""Dump json string. """Dump json string.

View file

@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import aliased, declarative_base, relationship from sqlalchemy.orm import aliased, declarative_base, relationship
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.components.recorder.const import SupportedDialect
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, ATTR_ATTRIBUTION,
ATTR_RESTORED, ATTR_RESTORED,
@ -287,7 +288,9 @@ class EventData(Base): # type: ignore[misc,valid-type]
) )
@staticmethod @staticmethod
def shared_data_bytes_from_event(event: Event) -> bytes: def shared_data_bytes_from_event(
event: Event, dialect: SupportedDialect | None
) -> bytes:
"""Create shared_data from an event.""" """Create shared_data from an event."""
return json_bytes(event.data) return json_bytes(event.data)
@ -438,7 +441,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
@staticmethod @staticmethod
def shared_attrs_bytes_from_event( def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]] event: Event,
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes: ) -> bytes:
"""Create shared_attrs from a state_changed event.""" """Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")

View file

@ -31,6 +31,7 @@ from homeassistant.components.recorder.const import (
EVENT_RECORDER_5MIN_STATISTICS_GENERATED, EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED, EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
KEEPALIVE_TIME, KEEPALIVE_TIME,
SupportedDialect,
) )
from homeassistant.components.recorder.db_schema import ( from homeassistant.components.recorder.db_schema import (
SCHEMA_VERSION, SCHEMA_VERSION,
@ -223,6 +224,42 @@ async def test_saving_state(recorder_mock, hass: HomeAssistant):
assert state == _state_with_context(hass, entity_id) assert state == _state_with_context(hass, entity_id)
@pytest.mark.parametrize(
    "dialect_name, expected_attributes",
    [
        (SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
        (SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}),
        (SupportedDialect.SQLITE, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
    ],
)
async def test_saving_state_with_nul(
    recorder_mock, hass: HomeAssistant, dialect_name, expected_attributes
):
    """Test that a NUL in a state attribute is stored as expected per dialect."""
    entity_id = "test.recorder"
    raw_attributes = {"test_attr": 5, "test_attr_10": "silly\0stuff"}

    # Make the recorder report the parametrized dialect while the state is
    # written and flushed to the database.
    dialect_patch = patch(
        "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
    )
    with dialect_patch:
        hass.states.async_set(entity_id, "restoring_from_db", raw_attributes)
        await async_wait_recording_done(hass)

    with session_scope(hass=hass) as session:
        rows = list(session.query(States, StateAttributes))
        # Exactly one state row, not linked to an event row, must exist.
        assert len(rows) == 1
        db_state, db_state_attributes = rows[0]
        assert db_state.event_id is None
        restored = db_state.to_native()
        restored.attributes = db_state_attributes.to_native()

    expected = _state_with_context(hass, entity_id)
    expected.attributes = expected_attributes
    assert restored == expected
async def test_saving_many_states( async def test_saving_many_states(
async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant
): ):

View file

@ -10,6 +10,7 @@ from homeassistant import core
from homeassistant.helpers.json import ( from homeassistant.helpers.json import (
ExtendedJSONEncoder, ExtendedJSONEncoder,
JSONEncoder, JSONEncoder,
json_bytes_strip_null,
json_dumps, json_dumps,
json_dumps_sorted, json_dumps_sorted,
) )
@ -118,3 +119,19 @@ def test_json_dumps_rgb_color_subclass():
rgb = RGBColor(4, 2, 1) rgb = RGBColor(4, 2, 1)
assert json_dumps(rgb) == "[4,2,1]" assert json_dumps(rgb) == "[4,2,1]"
def test_json_bytes_strip_null():
    """Test that strings are cut at the first NUL at any nesting depth."""
    # Each entry pairs the input data with the exact JSON bytes expected.
    cases = (
        ("\0", b'""'),
        ("silly\0stuff", b'"silly"'),
        (["one", "two\0", "three"], b'["one","two","three"]'),
        (
            {"k1": "one", "k2": "two\0", "k3": "three"},
            b'{"k1":"one","k2":"two","k3":"three"}',
        ),
        ([[{"k1": {"k2": ["silly\0stuff"]}}]], b'[[{"k1":{"k2":["silly"]}}]]'),
    )
    for data, expected in cases:
        assert json_bytes_strip_null(data) == expected