Strict typing for shared MQTT modules (#80913)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-10-26 13:52:34 +02:00 committed by GitHub
parent 052c673c9e
commit 2d9f39d406
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 48 deletions

View file

@ -182,7 +182,7 @@ MQTT_PUBLISH_SCHEMA = vol.All(
async def _async_setup_discovery(
hass: HomeAssistant, conf: ConfigType, config_entry
hass: HomeAssistant, conf: ConfigType, config_entry: ConfigEntry
) -> None:
"""Try to start the discovery of MQTT devices.
@ -377,7 +377,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
retain: bool = call.data[ATTR_RETAIN]
if msg_topic_template is not None:
try:
rendered_topic = template.Template(
rendered_topic: Any = template.Template(
msg_topic_template, hass
).async_render(parse_result=False)
msg_topic = valid_publish_topic(rendered_topic)
@ -620,12 +620,12 @@ def async_subscribe_connection_status(
"""Subscribe to MQTT connection changes."""
connection_status_callback_job = HassJob(connection_status_callback)
async def connected():
async def connected() -> None:
task = hass.async_run_hass_job(connection_status_callback_job, True)
if task:
await task
async def disconnected():
async def disconnected() -> None:
task = hass.async_run_hass_job(connection_status_callback_job, False)
if task:
await task
@ -636,7 +636,7 @@ def async_subscribe_connection_status(
}
@callback
def unsubscribe():
def unsubscribe() -> None:
subscriptions["connect"]()
subscriptions["disconnect"]()

View file

@ -245,7 +245,7 @@ def subscribe(
async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
).result()
def remove():
def remove() -> None:
"""Remove listener convert."""
run_callback_threadsafe(hass.loop, async_remove).result()
@ -341,7 +341,7 @@ class MQTT:
self._ha_started = asyncio.Event()
self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None
self._cleanup_on_unload: list[Callable] = []
self._cleanup_on_unload: list[Callable[[], None]] = []
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
self._pending_operations: dict[int, asyncio.Event] = {}
@ -352,14 +352,14 @@ class MQTT:
else:
@callback
def ha_started(_):
def ha_started(_: Event) -> None:
self._ha_started.set()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
self.init_client()
async def async_stop_mqtt(_event: Event):
async def async_stop_mqtt(_event: Event) -> None:
"""Stop MQTT component."""
await self.async_disconnect()
@ -506,9 +506,11 @@ class MQTT:
def _client_unsubscribe(topic: str) -> int:
result: int | None = None
mid: int | None = None
result, mid = self._mqttc.unsubscribe(topic)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result)
assert mid
return mid
if any(other.topic == topic for other in self.subscriptions):
@ -553,7 +555,13 @@ class MQTT:
if errors:
_raise_on_errors(errors)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
def _mqtt_on_connect(
self,
_mqttc: mqtt.Client,
_userdata: None,
_flags: dict[str, Any],
result_code: int,
) -> None:
"""On connect callback.
Resubscribe to all topics we were subscribed to and publish birth
@ -596,7 +604,7 @@ class MQTT:
and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE]
):
async def publish_birth_message(birth_message):
async def publish_birth_message(birth_message: PublishMessage) -> None:
await self._ha_started.wait() # Wait for Home Assistant to start
await self._discovery_cooldown() # Wait for MQTT discovery to cool down
await self.async_publish(
@ -611,7 +619,9 @@ class MQTT:
publish_birth_message(birth_message), self.hass.loop
)
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: MQTTMessage
) -> None:
"""Message received callback."""
self.hass.add_job(self._mqtt_handle_message, msg)
@ -663,7 +673,13 @@ class MQTT:
)
self._mqtt_data.state_write_requests.process_write_state_requests()
def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None:
def _mqtt_on_callback(
self,
_mqttc: mqtt.Client,
_userdata: None,
mid: int,
_granted_qos: tuple[Any, ...] | None = None,
) -> None:
"""Publish / Subscribe / Unsubscribe callback."""
self.hass.add_job(self._mqtt_handle_mid, mid)
@ -679,7 +695,9 @@ class MQTT:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
def _mqtt_on_disconnect(
self, _mqttc: mqtt.Client, _userdata: None, result_code: int
) -> None:
"""Disconnected callback."""
self.connected = False
dispatcher_send(self.hass, MQTT_DISCONNECTED)
@ -707,7 +725,7 @@ class MQTT:
del self._pending_operations[mid]
self._pending_operations_condition.notify_all()
async def _discovery_cooldown(self):
async def _discovery_cooldown(self) -> None:
now = time.time()
# Reset discovery and subscribe cooldowns
self._mqtt_data.last_discovery = now

View file

@ -287,7 +287,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
options_config: dict[str, Any] = {}
bad_input: bool = False
def _birth_will(birt_or_will: str) -> dict:
def _birth_will(birt_or_will: str) -> dict[str, Any]:
"""Return the user input for birth or will."""
assert user_input
return {
@ -298,8 +298,11 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
}
def _validate(
field: str, values: dict[str, Any], error_code: str, schema: Callable
):
field: str,
values: dict[str, Any],
error_code: str,
schema: Callable[[Any], Any],
) -> None:
"""Validate the user input."""
nonlocal bad_input
try:
@ -679,7 +682,9 @@ def try_connection(
result: queue.Queue[bool] = queue.Queue(maxsize=1)
def on_connect(client_, userdata, flags, result_code):
def on_connect(
client_: mqtt.Client, userdata: None, flags: dict[str, Any], result_code: int
) -> None:
"""Handle connection result."""
result.put(result_code == mqtt.CONNACK_ACCEPTED)

View file

@ -287,6 +287,7 @@ async def async_get_platform_config_from_yaml(
config_yaml: ConfigType | None = None,
) -> list[ConfigType]:
"""Return a list of validated configurations for the domain."""
platform_configs: Any | None
mqtt_data = get_mqtt_data(hass)
if config_yaml is None:
config_yaml = mqtt_data.config
@ -294,6 +295,7 @@ async def async_get_platform_config_from_yaml(
return []
if not (platform_configs := config_yaml.get(platform_domain)):
return []
assert isinstance(platform_configs, list)
return platform_configs
@ -662,7 +664,9 @@ def stop_discovery_updates(
clear_discovery_hash(hass, discovery_hash)
async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict):
async def async_remove_discovery_payload(
hass: HomeAssistant, discovery_data: DiscoveryInfoType
) -> None:
"""Clear retained discovery topic in broker to avoid rediscovery after a restart of HA."""
discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]
await async_publish(hass, discovery_topic, "", retain=True)
@ -820,7 +824,7 @@ class MqttDiscoveryUpdate(Entity):
"""Initialize the discovery update mixin."""
self._discovery_data = discovery_data
self._discovery_update = discovery_update
self._remove_discovery_updated: Callable | None = None
self._remove_discovery_updated: Callable[[], None] | None = None
self._removed_from_hass = False
if discovery_data is None:
return
@ -1169,7 +1173,7 @@ def update_device(
device_info = device_info_from_specifications(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None:
update_device_info = cast(dict, device_info)
update_device_info = cast(dict[str, Any], device_info)
update_device_info["config_entry_id"] = config_entry_id
device = device_registry.async_get_or_create(**update_device_info)

View file

@ -21,7 +21,7 @@ class EntitySubscription:
hass: HomeAssistant = attr.ib()
topic: str = attr.ib()
message_callback: MessageCallbackType = attr.ib()
subscribe_task: Coroutine | None = attr.ib()
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib()
unsubscribe_callback: Callable[[], None] | None = attr.ib()
qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8")
@ -53,7 +53,7 @@ class EntitySubscription:
hass, self.topic, self.message_callback, self.qos, self.encoding
)
async def subscribe(self):
async def subscribe(self) -> None:
"""Subscribe to a topic."""
if not self.subscribe_task:
return

View file

@ -40,58 +40,58 @@ def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
return not bool(hass.config_entries.async_entries(DOMAIN)[0].disabled_by)
def valid_topic(value: Any) -> str:
def valid_topic(topic: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
validated_topic = cv.string(topic)
try:
raw_value = value.encode("utf-8")
raw_validated_topic = validated_topic.encode("utf-8")
except UnicodeError as err:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") from err
if not raw_value:
if not raw_validated_topic:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
if len(raw_validated_topic) > 65535:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in value:
if "\0" in validated_topic:
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
if any(char <= "\u001F" for char in value):
if any(char <= "\u001F" for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
if any("\u007f" <= char <= "\u009F" for char in value):
if any("\u007f" <= char <= "\u009F" for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
if any("\ufdd0" <= char <= "\ufdef" for char in value):
if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain non-characters.")
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in value):
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.")
return value
return validated_topic
def valid_subscribe_topic(value: Any) -> str:
def valid_subscribe_topic(topic: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == "+"):
if (i > 0 and value[i - 1] != "/") or (
i < len(value) - 1 and value[i + 1] != "/"
validated_topic = valid_topic(topic)
for i in (i for i, c in enumerate(validated_topic) if c == "+"):
if (i > 0 and validated_topic[i - 1] != "/") or (
i < len(validated_topic) - 1 and validated_topic[i + 1] != "/"
):
raise vol.Invalid(
"Single-level wildcard must occupy an entire level of the filter"
)
index = value.find("#")
index = validated_topic.find("#")
if index != -1:
if index != len(value) - 1:
if index != len(validated_topic) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid(
"Multi-level wildcard must be the last "
"character in the topic filter."
)
if len(value) > 1 and value[index - 1] != "/":
if len(validated_topic) > 1 and validated_topic[index - 1] != "/":
raise vol.Invalid(
"Multi-level wildcard must be after a topic level separator."
)
return value
return validated_topic
def valid_subscribe_topic_template(value: Any) -> template.Template:
@ -104,12 +104,12 @@ def valid_subscribe_topic_template(value: Any) -> template.Template:
return tpl
def valid_publish_topic(value: Any) -> str:
def valid_publish_topic(topic: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
if "+" in value or "#" in value:
validated_topic = valid_topic(topic)
if "+" in validated_topic or "#" in validated_topic:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
return validated_topic
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))