Strict typing for shared MQTT modules (#80913)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
052c673c9e
commit
2d9f39d406
6 changed files with 75 additions and 48 deletions
|
@ -40,58 +40,58 @@ def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
|
|||
return not bool(hass.config_entries.async_entries(DOMAIN)[0].disabled_by)
|
||||
|
||||
|
||||
def valid_topic(value: Any) -> str:
|
||||
def valid_topic(topic: Any) -> str:
|
||||
"""Validate that this is a valid topic name/filter."""
|
||||
value = cv.string(value)
|
||||
validated_topic = cv.string(topic)
|
||||
try:
|
||||
raw_value = value.encode("utf-8")
|
||||
raw_validated_topic = validated_topic.encode("utf-8")
|
||||
except UnicodeError as err:
|
||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") from err
|
||||
if not raw_value:
|
||||
if not raw_validated_topic:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||
if len(raw_value) > 65535:
|
||||
if len(raw_validated_topic) > 65535:
|
||||
raise vol.Invalid(
|
||||
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
|
||||
)
|
||||
if "\0" in value:
|
||||
if "\0" in validated_topic:
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
|
||||
if any(char <= "\u001F" for char in value):
|
||||
if any(char <= "\u001F" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
|
||||
if any("\u007f" <= char <= "\u009F" for char in value):
|
||||
if any("\u007f" <= char <= "\u009F" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
|
||||
if any("\ufdd0" <= char <= "\ufdef" for char in value):
|
||||
if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain non-characters.")
|
||||
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in value):
|
||||
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic):
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.")
|
||||
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any) -> str:
|
||||
def valid_subscribe_topic(topic: Any) -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
for i in (i for i, c in enumerate(value) if c == "+"):
|
||||
if (i > 0 and value[i - 1] != "/") or (
|
||||
i < len(value) - 1 and value[i + 1] != "/"
|
||||
validated_topic = valid_topic(topic)
|
||||
for i in (i for i, c in enumerate(validated_topic) if c == "+"):
|
||||
if (i > 0 and validated_topic[i - 1] != "/") or (
|
||||
i < len(validated_topic) - 1 and validated_topic[i + 1] != "/"
|
||||
):
|
||||
raise vol.Invalid(
|
||||
"Single-level wildcard must occupy an entire level of the filter"
|
||||
)
|
||||
|
||||
index = value.find("#")
|
||||
index = validated_topic.find("#")
|
||||
if index != -1:
|
||||
if index != len(value) - 1:
|
||||
if index != len(validated_topic) - 1:
|
||||
# If there are multiple wildcards, this will also trigger
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be the last "
|
||||
"character in the topic filter."
|
||||
)
|
||||
if len(value) > 1 and value[index - 1] != "/":
|
||||
if len(validated_topic) > 1 and validated_topic[index - 1] != "/":
|
||||
raise vol.Invalid(
|
||||
"Multi-level wildcard must be after a topic level separator."
|
||||
)
|
||||
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
def valid_subscribe_topic_template(value: Any) -> template.Template:
|
||||
|
@ -104,12 +104,12 @@ def valid_subscribe_topic_template(value: Any) -> template.Template:
|
|||
return tpl
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
def valid_publish_topic(topic: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
if "+" in value or "#" in value:
|
||||
validated_topic = valid_topic(topic)
|
||||
if "+" in validated_topic or "#" in validated_topic:
|
||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||
return value
|
||||
return validated_topic
|
||||
|
||||
|
||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue