Terminate strings at NUL when recording states and events (#86687)
This commit is contained in:
parent
b9ffc67a44
commit
fea30c1ce9
6 changed files with 113 additions and 7 deletions
|
@ -836,7 +836,9 @@ class Recorder(threading.Thread):
|
|||
return
|
||||
|
||||
try:
|
||||
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
|
||||
shared_data_bytes = EventData.shared_data_bytes_from_event(
|
||||
event, self.dialect_name
|
||||
)
|
||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
|
||||
return
|
||||
|
@ -869,7 +871,7 @@ class Recorder(threading.Thread):
|
|||
try:
|
||||
dbstate = States.from_event(event)
|
||||
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
|
||||
event, self._exclude_attributes_by_domain
|
||||
event, self._exclude_attributes_by_domain, self.dialect_name
|
||||
)
|
||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||
_LOGGER.warning(
|
||||
|
|
|
@ -43,11 +43,12 @@ from homeassistant.helpers.json import (
|
|||
JSON_DECODE_EXCEPTIONS,
|
||||
JSON_DUMP,
|
||||
json_bytes,
|
||||
json_bytes_strip_null,
|
||||
json_loads,
|
||||
)
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
|
||||
from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
|
||||
from .models import StatisticData, StatisticMetaData, process_timestamp
|
||||
|
||||
# SQLAlchemy Schema
|
||||
|
@ -251,8 +252,12 @@ class EventData(Base): # type: ignore[misc,valid-type]
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def shared_data_bytes_from_event(event: Event) -> bytes:
|
||||
def shared_data_bytes_from_event(
|
||||
event: Event, dialect: SupportedDialect | None
|
||||
) -> bytes:
|
||||
"""Create shared_data from an event."""
|
||||
if dialect == SupportedDialect.POSTGRESQL:
|
||||
return json_bytes_strip_null(event.data)
|
||||
return json_bytes(event.data)
|
||||
|
||||
@staticmethod
|
||||
|
@ -416,7 +421,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
|||
|
||||
@staticmethod
|
||||
def shared_attrs_bytes_from_event(
|
||||
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
|
||||
event: Event,
|
||||
exclude_attrs_by_domain: dict[str, set[str]],
|
||||
dialect: SupportedDialect | None,
|
||||
) -> bytes:
|
||||
"""Create shared_attrs from a state_changed event."""
|
||||
state: State | None = event.data.get("new_state")
|
||||
|
@ -427,6 +434,10 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
|||
exclude_attrs = (
|
||||
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
|
||||
)
|
||||
if dialect == SupportedDialect.POSTGRESQL:
|
||||
return json_bytes_strip_null(
|
||||
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
|
||||
)
|
||||
return json_bytes(
|
||||
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
|
||||
)
|
||||
|
|
|
@ -71,6 +71,40 @@ def json_bytes(data: Any) -> bytes:
|
|||
)
|
||||
|
||||
|
||||
def json_bytes_strip_null(data: Any) -> bytes:
    """Serialize data to JSON bytes, ending every string value at its first NUL.

    The data is encoded once up front. Only when the encoded output contains
    an escaped NUL is it decoded again, cleaned and re-encoded. Dict keys are
    passed through unchanged; only values are processed.
    """

    def _truncate(value: Any) -> Any:
        """Return value with nested string values cut at the first NUL."""
        if isinstance(value, str):
            # Keep only the text in front of the first NUL character.
            return value.partition("\0")[0]
        if isinstance(value, dict):
            return {name: _truncate(item) for name, item in value.items()}
        if isinstance(value, list):
            return [_truncate(item) for item in value]
        return value

    # NUL characters are expected to be very rare, so encode optimistically
    # and only look for the escaped form in the output.
    encoded = json_bytes(data)
    if b"\\u0000" not in encoded:
        return encoded

    # Clean the decoded JSON rather than the original data, so Home Assistant
    # encoder extensions (sets, tuples, etc.) need no special handling here.
    return json_bytes(_truncate(orjson.loads(encoded)))
|
||||
|
||||
|
||||
def json_dumps(data: Any) -> str:
|
||||
"""Dump json string.
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declared_attr
|
|||
from sqlalchemy.orm import aliased, declarative_base, relationship
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from homeassistant.components.recorder.const import SupportedDialect
|
||||
from homeassistant.const import (
|
||||
ATTR_ATTRIBUTION,
|
||||
ATTR_RESTORED,
|
||||
|
@ -287,7 +288,9 @@ class EventData(Base): # type: ignore[misc,valid-type]
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def shared_data_bytes_from_event(event: Event) -> bytes:
|
||||
def shared_data_bytes_from_event(
|
||||
event: Event, dialect: SupportedDialect | None
|
||||
) -> bytes:
|
||||
"""Create shared_data from an event."""
|
||||
return json_bytes(event.data)
|
||||
|
||||
|
@ -438,7 +441,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
|||
|
||||
@staticmethod
|
||||
def shared_attrs_bytes_from_event(
|
||||
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
|
||||
event: Event,
|
||||
exclude_attrs_by_domain: dict[str, set[str]],
|
||||
dialect: SupportedDialect | None,
|
||||
) -> bytes:
|
||||
"""Create shared_attrs from a state_changed event."""
|
||||
state: State | None = event.data.get("new_state")
|
||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.components.recorder.const import (
|
|||
EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
|
||||
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
|
||||
KEEPALIVE_TIME,
|
||||
SupportedDialect,
|
||||
)
|
||||
from homeassistant.components.recorder.db_schema import (
|
||||
SCHEMA_VERSION,
|
||||
|
@ -223,6 +224,42 @@ async def test_saving_state(recorder_mock, hass: HomeAssistant):
|
|||
assert state == _state_with_context(hass, entity_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dialect_name, expected_attributes",
    (
        (SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
        (SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}),
        (SupportedDialect.SQLITE, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
    ),
)
async def test_saving_state_with_nul(
    recorder_mock, hass: HomeAssistant, dialect_name, expected_attributes
):
    """Test that a state whose attribute contains NUL is stored per dialect.

    The parametrized expectations keep the NUL for MySQL and SQLite and cut
    the string at the NUL for PostgreSQL.
    """
    entity_id = "test.recorder"
    attributes_with_nul = {"test_attr": 5, "test_attr_10": "silly\0stuff"}

    # Record the state while the recorder reports the dialect under test.
    dialect_patch = patch(
        "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
    )
    with dialect_patch:
        hass.states.async_set(entity_id, "restoring_from_db", attributes_with_nul)
        await async_wait_recording_done(hass)

    with session_scope(hass=hass) as session:
        recorded_states = []
        for row_state, row_attributes in session.query(States, StateAttributes):
            recorded_states.append(row_state)
            restored_state = row_state.to_native()
            restored_state.attributes = row_attributes.to_native()
        assert len(recorded_states) == 1
        assert recorded_states[0].event_id is None

    # Compare against the live state, with attributes swapped for what the
    # dialect is expected to have stored.
    expected_state = _state_with_context(hass, entity_id)
    expected_state.attributes = expected_attributes
    assert restored_state == expected_state
|
||||
|
||||
|
||||
async def test_saving_many_states(
|
||||
async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant
|
||||
):
|
||||
|
|
|
@ -10,6 +10,7 @@ from homeassistant import core
|
|||
from homeassistant.helpers.json import (
|
||||
ExtendedJSONEncoder,
|
||||
JSONEncoder,
|
||||
json_bytes_strip_null,
|
||||
json_dumps,
|
||||
json_dumps_sorted,
|
||||
)
|
||||
|
@ -118,3 +119,19 @@ def test_json_dumps_rgb_color_subclass():
|
|||
rgb = RGBColor(4, 2, 1)
|
||||
|
||||
assert json_dumps(rgb) == "[4,2,1]"
|
||||
|
||||
|
||||
def test_json_bytes_strip_null():
    """Test that strings are cut at the first NUL before being encoded."""
    # Each case pairs an input with the exact JSON bytes expected back.
    cases = (
        ("\0", b'""'),
        ("silly\0stuff", b'"silly"'),
        (["one", "two\0", "three"], b'["one","two","three"]'),
        (
            {"k1": "one", "k2": "two\0", "k3": "three"},
            b'{"k1":"one","k2":"two","k3":"three"}',
        ),
        (
            [[{"k1": {"k2": ["silly\0stuff"]}}]],
            b'[[{"k1":{"k2":["silly"]}}]]',
        ),
    )
    for data, expected in cases:
        assert json_bytes_strip_null(data) == expected
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue