Terminate strings at NUL when recording states and events (#86687)
This commit is contained in:
parent
b9ffc67a44
commit
fea30c1ce9
6 changed files with 113 additions and 7 deletions
|
@ -836,7 +836,9 @@ class Recorder(threading.Thread):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
|
shared_data_bytes = EventData.shared_data_bytes_from_event(
|
||||||
|
event, self.dialect_name
|
||||||
|
)
|
||||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||||
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
|
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
|
||||||
return
|
return
|
||||||
|
@ -869,7 +871,7 @@ class Recorder(threading.Thread):
|
||||||
try:
|
try:
|
||||||
dbstate = States.from_event(event)
|
dbstate = States.from_event(event)
|
||||||
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
|
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
|
||||||
event, self._exclude_attributes_by_domain
|
event, self._exclude_attributes_by_domain, self.dialect_name
|
||||||
)
|
)
|
||||||
except JSON_ENCODE_EXCEPTIONS as ex:
|
except JSON_ENCODE_EXCEPTIONS as ex:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
|
|
|
@ -43,11 +43,12 @@ from homeassistant.helpers.json import (
|
||||||
JSON_DECODE_EXCEPTIONS,
|
JSON_DECODE_EXCEPTIONS,
|
||||||
JSON_DUMP,
|
JSON_DUMP,
|
||||||
json_bytes,
|
json_bytes,
|
||||||
|
json_bytes_strip_null,
|
||||||
json_loads,
|
json_loads,
|
||||||
)
|
)
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
|
from .const import ALL_DOMAIN_EXCLUDE_ATTRS, SupportedDialect
|
||||||
from .models import StatisticData, StatisticMetaData, process_timestamp
|
from .models import StatisticData, StatisticMetaData, process_timestamp
|
||||||
|
|
||||||
# SQLAlchemy Schema
|
# SQLAlchemy Schema
|
||||||
|
@ -251,8 +252,12 @@ class EventData(Base): # type: ignore[misc,valid-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared_data_bytes_from_event(event: Event) -> bytes:
|
def shared_data_bytes_from_event(
|
||||||
|
event: Event, dialect: SupportedDialect | None
|
||||||
|
) -> bytes:
|
||||||
"""Create shared_data from an event."""
|
"""Create shared_data from an event."""
|
||||||
|
if dialect == SupportedDialect.POSTGRESQL:
|
||||||
|
return json_bytes_strip_null(event.data)
|
||||||
return json_bytes(event.data)
|
return json_bytes(event.data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -416,7 +421,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared_attrs_bytes_from_event(
|
def shared_attrs_bytes_from_event(
|
||||||
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
|
event: Event,
|
||||||
|
exclude_attrs_by_domain: dict[str, set[str]],
|
||||||
|
dialect: SupportedDialect | None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Create shared_attrs from a state_changed event."""
|
"""Create shared_attrs from a state_changed event."""
|
||||||
state: State | None = event.data.get("new_state")
|
state: State | None = event.data.get("new_state")
|
||||||
|
@ -427,6 +434,10 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
||||||
exclude_attrs = (
|
exclude_attrs = (
|
||||||
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
|
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
|
||||||
)
|
)
|
||||||
|
if dialect == SupportedDialect.POSTGRESQL:
|
||||||
|
return json_bytes_strip_null(
|
||||||
|
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
|
||||||
|
)
|
||||||
return json_bytes(
|
return json_bytes(
|
||||||
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
|
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
|
||||||
)
|
)
|
||||||
|
|
|
@ -71,6 +71,40 @@ def json_bytes(data: Any) -> bytes:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def json_bytes_strip_null(data: Any) -> bytes:
    """Dump json bytes after terminating strings at the first NUL.

    Every string in the encoded document -- dict keys as well as values,
    at any nesting depth -- is cut at its first NUL character, so the
    returned bytes never contain an escaped NUL (``\\u0000``).

    ``data`` is anything ``json_bytes`` can encode; the encoded JSON
    document is returned as bytes.
    """

    def strip_null(obj: Any) -> Any:
        """Return a copy of obj with every string cut at its first NUL."""
        if isinstance(obj, str):
            return obj.split("\0", 1)[0]
        if isinstance(obj, dict):
            # Keys are strings too: leaving a NUL in a key would put
            # "\u0000" straight back into the re-encoded output. If two
            # keys become equal after truncation, the later one wins.
            return {strip_null(key): strip_null(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [strip_null(item) for item in obj]
        return obj

    # We expect null-characters to be very rare, hence try encoding first and
    # look for an escaped null-character in the output.
    result = json_bytes(data)
    if b"\\u0000" not in result:
        return result

    # The byte check can also match an escaped backslash followed by the text
    # "u0000"; that only costs an unnecessary pass through the slow path, the
    # data itself is left unchanged.
    #
    # We work on the decoded result so we don't need to worry about
    # Home Assistant extensions which allow encoding sets, tuples, etc.
    return json_bytes(strip_null(orjson.loads(result)))
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(data: Any) -> str:
|
def json_dumps(data: Any) -> str:
|
||||||
"""Dump json string.
|
"""Dump json string.
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declared_attr
|
||||||
from sqlalchemy.orm import aliased, declarative_base, relationship
|
from sqlalchemy.orm import aliased, declarative_base, relationship
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
from homeassistant.components.recorder.const import SupportedDialect
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ATTRIBUTION,
|
ATTR_ATTRIBUTION,
|
||||||
ATTR_RESTORED,
|
ATTR_RESTORED,
|
||||||
|
@ -287,7 +288,9 @@ class EventData(Base): # type: ignore[misc,valid-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared_data_bytes_from_event(event: Event) -> bytes:
|
def shared_data_bytes_from_event(
|
||||||
|
event: Event, dialect: SupportedDialect | None
|
||||||
|
) -> bytes:
|
||||||
"""Create shared_data from an event."""
|
"""Create shared_data from an event."""
|
||||||
return json_bytes(event.data)
|
return json_bytes(event.data)
|
||||||
|
|
||||||
|
@ -438,7 +441,9 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared_attrs_bytes_from_event(
|
def shared_attrs_bytes_from_event(
|
||||||
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
|
event: Event,
|
||||||
|
exclude_attrs_by_domain: dict[str, set[str]],
|
||||||
|
dialect: SupportedDialect | None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Create shared_attrs from a state_changed event."""
|
"""Create shared_attrs from a state_changed event."""
|
||||||
state: State | None = event.data.get("new_state")
|
state: State | None = event.data.get("new_state")
|
||||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.components.recorder.const import (
|
||||||
EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
|
EVENT_RECORDER_5MIN_STATISTICS_GENERATED,
|
||||||
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
|
EVENT_RECORDER_HOURLY_STATISTICS_GENERATED,
|
||||||
KEEPALIVE_TIME,
|
KEEPALIVE_TIME,
|
||||||
|
SupportedDialect,
|
||||||
)
|
)
|
||||||
from homeassistant.components.recorder.db_schema import (
|
from homeassistant.components.recorder.db_schema import (
|
||||||
SCHEMA_VERSION,
|
SCHEMA_VERSION,
|
||||||
|
@ -223,6 +224,42 @@ async def test_saving_state(recorder_mock, hass: HomeAssistant):
|
||||||
assert state == _state_with_context(hass, entity_id)
|
assert state == _state_with_context(hass, entity_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "dialect_name, expected_attributes",
    (
        (SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
        (SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}),
        (SupportedDialect.SQLITE, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
    ),
)
async def test_saving_state_with_nul(
    recorder_mock, hass: HomeAssistant, dialect_name, expected_attributes
):
    """Test that a state whose attributes contain NUL is recorded per dialect.

    The expected attributes differ by dialect: only the PostgreSQL case
    expects the string to be cut at the NUL.
    """
    entity_id = "test.recorder"
    raw_attributes = {"test_attr": 5, "test_attr_10": "silly\0stuff"}

    # Make the recorder believe it is talking to the parametrized dialect
    # while the state is being written.
    with patch(
        "homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
    ):
        hass.states.async_set(entity_id, "restoring_from_db", raw_attributes)
        await async_wait_recording_done(hass)

    restored_state = None
    with session_scope(hass=hass) as session:
        recorded_rows = []
        for state_row, attributes_row in session.query(States, StateAttributes):
            recorded_rows.append(state_row)
            restored_state = state_row.to_native()
            restored_state.attributes = attributes_row.to_native()
        assert len(recorded_rows) == 1
        assert recorded_rows[0].event_id is None

    expected_state = _state_with_context(hass, entity_id)
    expected_state.attributes = expected_attributes
    assert restored_state == expected_state
|
||||||
|
|
||||||
|
|
||||||
async def test_saving_many_states(
|
async def test_saving_many_states(
|
||||||
async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant
|
async_setup_recorder_instance: SetupRecorderInstanceT, hass: HomeAssistant
|
||||||
):
|
):
|
||||||
|
|
|
@ -10,6 +10,7 @@ from homeassistant import core
|
||||||
from homeassistant.helpers.json import (
|
from homeassistant.helpers.json import (
|
||||||
ExtendedJSONEncoder,
|
ExtendedJSONEncoder,
|
||||||
JSONEncoder,
|
JSONEncoder,
|
||||||
|
json_bytes_strip_null,
|
||||||
json_dumps,
|
json_dumps,
|
||||||
json_dumps_sorted,
|
json_dumps_sorted,
|
||||||
)
|
)
|
||||||
|
@ -118,3 +119,19 @@ def test_json_dumps_rgb_color_subclass():
|
||||||
rgb = RGBColor(4, 2, 1)
|
rgb = RGBColor(4, 2, 1)
|
||||||
|
|
||||||
assert json_dumps(rgb) == "[4,2,1]"
|
assert json_dumps(rgb) == "[4,2,1]"
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_bytes_strip_null():
    """Test that strings are cut at the first NUL at any nesting depth."""
    # (input, expected encoded bytes), checked in order.
    cases = (
        ("\0", b'""'),
        ("silly\0stuff", b'"silly"'),
        (["one", "two\0", "three"], b'["one","two","three"]'),
        (
            {"k1": "one", "k2": "two\0", "k3": "three"},
            b'{"k1":"one","k2":"two","k3":"three"}',
        ),
        (
            [[{"k1": {"k2": ["silly\0stuff"]}}]],
            b'[[{"k1":{"k2":["silly"]}}]]',
        ),
    )
    for value, expected in cases:
        assert json_bytes_strip_null(value) == expected
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue