Add return type to json_loads (#85672)

* Add JSON type definitions

* Sample use

* Keep mutable for a follo-up PR (avoid dead code)

* Use list/dict

* Remove JsonObjectType

* Remove reference to Union

* Cleanup

* Improve rest

* Rename json_dict => json_data

* Add docstring

* Add type hint to json_loads

* Add cast

* Move type alias to json helpers

* Cleanup

* Create and use json_loads_object

* Make error more explicit and add tests

* Use JsonObjectType in conversation

* Remove quotes
This commit is contained in:
epenet 2023-02-07 17:21:55 +01:00 committed by GitHub
parent 20b60d57f2
commit a202588fd2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 70 additions and 28 deletions

View file

@ -18,7 +18,7 @@ import yaml
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.helpers import area_registry, entity_registry, intent, template from homeassistant.helpers import area_registry, entity_registry, intent, template
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import JsonObjectType, json_loads_object
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN from .const import DOMAIN
@ -29,9 +29,9 @@ _DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
def json_load(fp: IO[str]) -> dict[str, Any]: def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents.""" """Wrap json_loads for get_intents."""
return json_loads(fp.read()) return json_loads_object(fp.read())
@dataclass @dataclass

View file

@ -14,7 +14,7 @@ from nacl.secret import SecretBox
from homeassistant.const import ATTR_DEVICE_ID, CONTENT_TYPE_JSON from homeassistant.const import ATTR_DEVICE_ID, CONTENT_TYPE_JSON
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.json import JSONEncoder, json_loads from homeassistant.helpers.json import JSONEncoder, JsonObjectType, json_loads_object
from .const import ( from .const import (
ATTR_APP_DATA, ATTR_APP_DATA,
@ -71,7 +71,7 @@ def _decrypt_payload_helper(
ciphertext: str, ciphertext: str,
get_key_bytes: Callable[[str, int], str | bytes], get_key_bytes: Callable[[str, int], str | bytes],
key_encoder, key_encoder,
) -> dict[str, str] | None: ) -> JsonObjectType | None:
"""Decrypt encrypted payload.""" """Decrypt encrypted payload."""
try: try:
keylen, decrypt = setup_decrypt(key_encoder) keylen, decrypt = setup_decrypt(key_encoder)
@ -86,12 +86,12 @@ def _decrypt_payload_helper(
key_bytes = get_key_bytes(key, keylen) key_bytes = get_key_bytes(key, keylen)
msg_bytes = decrypt(ciphertext, key_bytes) msg_bytes = decrypt(ciphertext, key_bytes)
message = json_loads(msg_bytes) message = json_loads_object(msg_bytes)
_LOGGER.debug("Successfully decrypted mobile_app payload") _LOGGER.debug("Successfully decrypted mobile_app payload")
return message return message
def _decrypt_payload(key: str | None, ciphertext: str) -> dict[str, str] | None: def _decrypt_payload(key: str | None, ciphertext: str) -> JsonObjectType | None:
"""Decrypt encrypted payload.""" """Decrypt encrypted payload."""
def get_key_bytes(key: str, keylen: int) -> str: def get_key_bytes(key: str, keylen: int) -> str:
@ -100,7 +100,7 @@ def _decrypt_payload(key: str | None, ciphertext: str) -> dict[str, str] | None:
return _decrypt_payload_helper(key, ciphertext, get_key_bytes, HexEncoder) return _decrypt_payload_helper(key, ciphertext, get_key_bytes, HexEncoder)
def _decrypt_payload_legacy(key: str | None, ciphertext: str) -> dict[str, str] | None: def _decrypt_payload_legacy(key: str | None, ciphertext: str) -> JsonObjectType | None:
"""Decrypt encrypted payload.""" """Decrypt encrypted payload."""
def get_key_bytes(key: str, keylen: int) -> bytes: def get_key_bytes(key: str, keylen: int) -> bytes:

View file

@ -18,7 +18,7 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
) )
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads_object
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.helpers.typing import DiscoveryInfoType from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.loader import async_get_mqtt from homeassistant.loader import async_get_mqtt
@ -126,7 +126,7 @@ async def async_start( # noqa: C901
if payload: if payload:
try: try:
discovery_payload = MQTTDiscoveryPayload(json_loads(payload)) discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
except ValueError: except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return return

View file

@ -47,7 +47,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import json_dumps, json_loads from homeassistant.helpers.json import json_dumps, json_loads_object
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
@ -343,7 +343,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_received(msg: ReceiveMessage) -> None: def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
values: dict[str, Any] = json_loads(msg.payload) values = json_loads_object(msg.payload)
if values["state"] == "ON": if values["state"] == "ON":
self._attr_is_on = True self._attr_is_on = True
@ -369,7 +369,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
if brightness_supported(self.supported_color_modes): if brightness_supported(self.supported_color_modes):
try: try:
self._attr_brightness = int( self._attr_brightness = int(
values["brightness"] values["brightness"] # type: ignore[operator]
/ float(self._config[CONF_BRIGHTNESS_SCALE]) / float(self._config[CONF_BRIGHTNESS_SCALE])
* 255 * 255
) )
@ -391,7 +391,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
if values["color_temp"] is None: if values["color_temp"] is None:
self._attr_color_temp = None self._attr_color_temp = None
else: else:
self._attr_color_temp = int(values["color_temp"]) self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError: except KeyError:
pass pass
except ValueError: except ValueError:
@ -402,7 +402,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
if self.supported_features and LightEntityFeature.EFFECT: if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError): with suppress(KeyError):
self._attr_effect = values["effect"] self._attr_effect = cast(str, values["effect"])
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)

View file

@ -31,7 +31,11 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads from homeassistant.helpers.json import (
JSON_DECODE_EXCEPTIONS,
json_dumps,
json_loads_object,
)
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
@ -245,7 +249,7 @@ class MqttSiren(MqttEntity, SirenEntity):
json_payload = {STATE: payload} json_payload = {STATE: payload}
else: else:
try: try:
json_payload = json_loads(payload) json_payload = json_loads_object(payload)
_LOGGER.debug( _LOGGER.debug(
( (
"JSON payload detected after processing payload '%s' on" "JSON payload detected after processing payload '%s' on"

View file

@ -1,7 +1,7 @@
"""Support for a State MQTT vacuum.""" """Support for a State MQTT vacuum."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -25,7 +25,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import json_dumps, json_loads from homeassistant.helpers.json import json_dumps, json_loads_object
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .. import subscription from .. import subscription
@ -240,12 +240,12 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_message_received(msg: ReceiveMessage) -> None: def state_message_received(msg: ReceiveMessage) -> None:
"""Handle state MQTT message.""" """Handle state MQTT message."""
payload: dict[str, Any] = json_loads(msg.payload) payload = json_loads_object(msg.payload)
if STATE in payload and ( if STATE in payload and (
payload[STATE] in POSSIBLE_STATES or payload[STATE] is None (state := payload[STATE]) in POSSIBLE_STATES or state is None
): ):
self._attr_state = ( self._attr_state = (
POSSIBLE_STATES[payload[STATE]] if payload[STATE] else None POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
) )
del payload[STATE] del payload[STATE]
self._update_state_attributes(payload) self._update_state_attributes(payload)

View file

@ -46,6 +46,7 @@ from homeassistant.helpers.json import (
json_bytes, json_bytes,
json_bytes_strip_null, json_bytes_strip_null,
json_loads, json_loads,
json_loads_object,
) )
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -211,7 +212,7 @@ class Events(Base): # type: ignore[misc,valid-type]
try: try:
return Event( return Event(
self.event_type, self.event_type,
json_loads(self.event_data) if self.event_data else {}, json_loads_object(self.event_data) if self.event_data else {},
EventOrigin(self.origin) EventOrigin(self.origin)
if self.origin if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx], else EVENT_ORIGIN_ORDER[self.origin_idx],
@ -358,7 +359,7 @@ class States(Base): # type: ignore[misc,valid-type]
parent_id=self.context_parent_id, parent_id=self.context_parent_id,
) )
try: try:
attrs = json_loads(self.attributes) if self.attributes else {} attrs = json_loads_object(self.attributes) if self.attributes else {}
except JSON_DECODE_EXCEPTIONS: except JSON_DECODE_EXCEPTIONS:
# When json_loads fails # When json_loads fails
_LOGGER.exception("Error converting row to state: %s", self) _LOGGER.exception("Error converting row to state: %s", self)

View file

@ -16,7 +16,7 @@ from homeassistant.const import (
COMPRESSED_STATE_STATE, COMPRESSED_STATE_STATE,
) )
from homeassistant.core import Context, State from homeassistant.core import Context, State
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads_object
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import SupportedDialect from .const import SupportedDialect
@ -347,7 +347,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT: if not source or source == EMPTY_JSON_OBJECT:
return {} return {}
try: try:
attr_cache[source] = attributes = json_loads(source) attr_cache[source] = attributes = json_loads_object(source)
except ValueError: except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source) _LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {} attr_cache[source] = attributes = {}

View file

@ -1,4 +1,5 @@
"""Helpers to help with encoding Home Assistant objects in JSON.""" """Helpers to help with encoding Home Assistant objects in JSON."""
from collections.abc import Callable
import datetime import datetime
import json import json
from pathlib import Path from pathlib import Path
@ -6,6 +7,13 @@ from typing import Any, Final
import orjson import orjson
JsonValueType = (
dict[str, "JsonValueType"] | list["JsonValueType"] | str | int | float | bool | None
)
"""Any data that can be returned by the standard JSON deserializing process."""
JsonObjectType = dict[str, JsonValueType]
"""Dictionary that can be returned by the standard JSON deserializing process."""
JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError) JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,) JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)
@ -132,7 +140,18 @@ def json_dumps_sorted(data: Any) -> str:
).decode("utf-8") ).decode("utf-8")
json_loads: Callable[[bytes | bytearray | memoryview | str], JsonValueType]
json_loads = orjson.loads json_loads = orjson.loads
"""Parse JSON data."""
def json_loads_object(__obj: bytes | bytearray | memoryview | str) -> JsonObjectType:
"""Parse JSON data and ensure result is a dictionary."""
value: JsonValueType = json_loads(__obj)
# Avoid isinstance overhead as we are not interested in dict subclasses
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
return value
raise ValueError(f"Expected JSON to be parsed as a dict got {type(value)}")
JSON_DUMP: Final = json_dumps JSON_DUMP: Final = json_dumps

View file

@ -259,7 +259,7 @@ async def async_get_integration_descriptions(
config_flow_path = pathlib.Path(base) / "integrations.json" config_flow_path = pathlib.Path(base) / "integrations.json"
flow = await hass.async_add_executor_job(config_flow_path.read_text) flow = await hass.async_add_executor_job(config_flow_path.read_text)
core_flows: dict[str, Any] = json_loads(flow) core_flows = cast(dict[str, Any], json_loads(flow))
custom_integrations = await async_get_custom_components(hass) custom_integrations = await async_get_custom_components(hass)
custom_flows: dict[str, Any] = { custom_flows: dict[str, Any] = {
"integration": {}, "integration": {},
@ -474,7 +474,7 @@ class Integration:
continue continue
try: try:
manifest = json_loads(manifest_path.read_text()) manifest = cast(Manifest, json_loads(manifest_path.read_text()))
except JSON_DECODE_EXCEPTIONS as err: except JSON_DECODE_EXCEPTIONS as err:
_LOGGER.error( _LOGGER.error(
"Error parsing manifest.json file at %s: %s", manifest_path, err "Error parsing manifest.json file at %s: %s", manifest_path, err

View file

@ -13,6 +13,7 @@ from homeassistant.helpers.json import (
json_bytes_strip_null, json_bytes_strip_null,
json_dumps, json_dumps,
json_dumps_sorted, json_dumps_sorted,
json_loads_object,
) )
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.color import RGBColor from homeassistant.util.color import RGBColor
@ -135,3 +136,20 @@ def test_json_bytes_strip_null() -> None:
json_bytes_strip_null([[{"k1": {"k2": ["silly\0stuff"]}}]]) json_bytes_strip_null([[{"k1": {"k2": ["silly\0stuff"]}}]])
== b'[[{"k1":{"k2":["silly"]}}]]' == b'[[{"k1":{"k2":["silly"]}}]]'
) )
def test_json_loads_object():
"""Test json_loads_object validates result."""
assert json_loads_object('{"c":1.2}') == {"c": 1.2}
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'list'>"
):
json_loads_object("[]")
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'bool'>"
):
json_loads_object("true")
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'NoneType'>"
):
json_loads_object("null")