From 2d9f39d40631d95824b43c63e4ccad4130cc296b Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Wed, 26 Oct 2022 13:52:34 +0200 Subject: [PATCH] Strict typing for shared MQTT modules (#80913) Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- homeassistant/components/mqtt/__init__.py | 10 ++-- homeassistant/components/mqtt/client.py | 38 +++++++++++---- homeassistant/components/mqtt/config_flow.py | 13 +++-- homeassistant/components/mqtt/mixins.py | 10 ++-- homeassistant/components/mqtt/subscription.py | 4 +- homeassistant/components/mqtt/util.py | 48 +++++++++---------- 6 files changed, 75 insertions(+), 48 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 0b46ef0a32b..06921105aae 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -182,7 +182,7 @@ MQTT_PUBLISH_SCHEMA = vol.All( async def _async_setup_discovery( - hass: HomeAssistant, conf: ConfigType, config_entry + hass: HomeAssistant, conf: ConfigType, config_entry: ConfigEntry ) -> None: """Try to start the discovery of MQTT devices. @@ -377,7 +377,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: retain: bool = call.data[ATTR_RETAIN] if msg_topic_template is not None: try: - rendered_topic = template.Template( + rendered_topic: Any = template.Template( msg_topic_template, hass ).async_render(parse_result=False) msg_topic = valid_publish_topic(rendered_topic) @@ -620,12 +620,12 @@ def async_subscribe_connection_status( """Subscribe to MQTT connection changes.""" connection_status_callback_job = HassJob(connection_status_callback) - async def connected(): + async def connected() -> None: task = hass.async_run_hass_job(connection_status_callback_job, True) if task: await task - async def disconnected(): + async def disconnected() -> None: task = hass.async_run_hass_job(connection_status_callback_job, False) if task: await task @@ -636,7 +636,7 @@ def async_subscribe_connection_status( } @callback - def unsubscribe(): + def unsubscribe() -> None: subscriptions["connect"]() subscriptions["disconnect"]() diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 041e3cfa374..e909a378581 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -245,7 +245,7 @@ def subscribe( async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop ).result() - def remove(): + def remove() -> None: """Remove listener convert.""" run_callback_threadsafe(hass.loop, async_remove).result() @@ -341,7 +341,7 @@ class MQTT: self._ha_started = asyncio.Event() self._last_subscribe = time.time() self._mqttc: mqtt.Client = None - self._cleanup_on_unload: list[Callable] = [] + self._cleanup_on_unload: list[Callable[[], None]] = [] self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client self._pending_operations: dict[int, asyncio.Event] = {} @@ -352,14 +352,14 @@ class MQTT: else: @callback - def ha_started(_): + def ha_started(_: Event) -> None: self._ha_started.set() self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started) self.init_client() - async def async_stop_mqtt(_event: Event): + async def async_stop_mqtt(_event: Event) -> None: """Stop MQTT component.""" await self.async_disconnect() @@ -506,9 +506,11 @@ class MQTT: def _client_unsubscribe(topic: str) -> int: result: int | None = None + mid: int | None = None result, mid = self._mqttc.unsubscribe(topic) _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _raise_on_error(result) + assert mid return mid if any(other.topic == topic for other in self.subscriptions): @@ -553,7 +555,13 @@ class MQTT: if errors: _raise_on_errors(errors) - def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None: + def _mqtt_on_connect( + self, + _mqttc: mqtt.Client, + _userdata: None, + _flags: dict[str, Any], + result_code: int, + ) -> None: """On connect callback. Resubscribe to all topics we were subscribed to and publish birth @@ -596,7 +604,7 @@ class MQTT: and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE] ): - async def publish_birth_message(birth_message): + async def publish_birth_message(birth_message: PublishMessage) -> None: await self._ha_started.wait() # Wait for Home Assistant to start await self._discovery_cooldown() # Wait for MQTT discovery to cool down await self.async_publish( @@ -611,7 +619,9 @@ class MQTT: publish_birth_message(birth_message), self.hass.loop ) - def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: + def _mqtt_on_message( + self, _mqttc: mqtt.Client, _userdata: None, msg: MQTTMessage + ) -> None: """Message received callback.""" self.hass.add_job(self._mqtt_handle_message, msg) @@ -663,7 +673,13 @@ class MQTT: ) self._mqtt_data.state_write_requests.process_write_state_requests() - def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None: + def _mqtt_on_callback( + self, + _mqttc: mqtt.Client, + _userdata: None, + mid: int, + _granted_qos: tuple[Any, ...] | None = None, + ) -> None: """Publish / Subscribe / Unsubscribe callback.""" self.hass.add_job(self._mqtt_handle_mid, mid) @@ -679,7 +695,9 @@ class MQTT: if mid not in self._pending_operations: self._pending_operations[mid] = asyncio.Event() - def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: + def _mqtt_on_disconnect( + self, _mqttc: mqtt.Client, _userdata: None, result_code: int + ) -> None: """Disconnected callback.""" self.connected = False dispatcher_send(self.hass, MQTT_DISCONNECTED) @@ -707,7 +725,7 @@ class MQTT: del self._pending_operations[mid] self._pending_operations_condition.notify_all() - async def _discovery_cooldown(self): + async def _discovery_cooldown(self) -> None: now = time.time() # Reset discovery and subscribe cooldowns self._mqtt_data.last_discovery = now diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index d94a2648918..ec818348701 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -287,7 +287,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): options_config: dict[str, Any] = {} bad_input: bool = False - def _birth_will(birt_or_will: str) -> dict: + def _birth_will(birt_or_will: str) -> dict[str, Any]: """Return the user input for birth or will.""" assert user_input return { @@ -298,8 +298,11 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): } def _validate( - field: str, values: dict[str, Any], error_code: str, schema: Callable - ): + field: str, + values: dict[str, Any], + error_code: str, + schema: Callable[[Any], Any], + ) -> None: """Validate the user input.""" nonlocal bad_input try: @@ -679,7 +682,9 @@ def try_connection( result: queue.Queue[bool] = queue.Queue(maxsize=1) - def on_connect(client_, userdata, flags, result_code): + def on_connect( + client_: mqtt.Client, userdata: None, flags: dict[str, Any], result_code: int + ) -> None: """Handle connection result.""" result.put(result_code == mqtt.CONNACK_ACCEPTED) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 7866e3cf6d6..a91a7fc7a88 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -287,6 +287,7 @@ async def async_get_platform_config_from_yaml( config_yaml: ConfigType | None = None, ) -> list[ConfigType]: """Return a list of validated configurations for the domain.""" + platform_configs: Any | None mqtt_data = get_mqtt_data(hass) if config_yaml is None: config_yaml = mqtt_data.config @@ -294,6 +295,7 @@ async def async_get_platform_config_from_yaml( return [] if not (platform_configs := config_yaml.get(platform_domain)): return [] + assert isinstance(platform_configs, list) return platform_configs @@ -662,7 +664,9 @@ def stop_discovery_updates( clear_discovery_hash(hass, discovery_hash) -async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: dict): +async def async_remove_discovery_payload( + hass: HomeAssistant, discovery_data: DiscoveryInfoType +) -> None: """Clear retained discovery topic in broker to avoid rediscovery after a restart of HA.""" discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC] await async_publish(hass, discovery_topic, "", retain=True) @@ -820,7 +824,7 @@ class MqttDiscoveryUpdate(Entity): """Initialize the discovery update mixin.""" self._discovery_data = discovery_data self._discovery_update = discovery_update - self._remove_discovery_updated: Callable | None = None + self._remove_discovery_updated: Callable[[], None] | None = None self._removed_from_hass = False if discovery_data is None: return @@ -1169,7 +1173,7 @@ def update_device( device_info = device_info_from_specifications(config[CONF_DEVICE]) if config_entry_id is not None and device_info is not None: - update_device_info = cast(dict, device_info) + update_device_info = cast(dict[str, Any], device_info) update_device_info["config_entry_id"] = config_entry_id device = device_registry.async_get_or_create(**update_device_info) diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 05f7f3934ee..87f5d3882bb 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -21,7 +21,7 @@ class EntitySubscription: hass: HomeAssistant = attr.ib() topic: str = attr.ib() message_callback: MessageCallbackType = attr.ib() - subscribe_task: Coroutine | None = attr.ib() + subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib() unsubscribe_callback: Callable[[], None] | None = attr.ib() qos: int = attr.ib(default=0) encoding: str = attr.ib(default="utf-8") @@ -53,7 +53,7 @@ class EntitySubscription: hass, self.topic, self.message_callback, self.qos, self.encoding ) - async def subscribe(self): + async def subscribe(self) -> None: """Subscribe to a topic.""" if not self.subscribe_task: return diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index 7e23b6c01f1..0b2d10977aa 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -40,58 +40,58 @@ def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: return not bool(hass.config_entries.async_entries(DOMAIN)[0].disabled_by) -def valid_topic(value: Any) -> str: +def valid_topic(topic: Any) -> str: """Validate that this is a valid topic name/filter.""" - value = cv.string(value) + validated_topic = cv.string(topic) try: - raw_value = value.encode("utf-8") + raw_validated_topic = validated_topic.encode("utf-8") except UnicodeError as err: raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") from err - if not raw_value: + if not raw_validated_topic: raise vol.Invalid("MQTT topic name/filter must not be empty.") - if len(raw_value) > 65535: + if len(raw_validated_topic) > 65535: raise vol.Invalid( "MQTT topic name/filter must not be longer than 65535 encoded bytes." ) - if "\0" in value: + if "\0" in validated_topic: raise vol.Invalid("MQTT topic name/filter must not contain null character.") - if any(char <= "\u001F" for char in value): + if any(char <= "\u001F" for char in validated_topic): raise vol.Invalid("MQTT topic name/filter must not contain control characters.") - if any("\u007f" <= char <= "\u009F" for char in value): + if any("\u007f" <= char <= "\u009F" for char in validated_topic): raise vol.Invalid("MQTT topic name/filter must not contain control characters.") - if any("\ufdd0" <= char <= "\ufdef" for char in value): + if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic): raise vol.Invalid("MQTT topic name/filter must not contain non-characters.") - if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in value): + if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic): raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.") - return value + return validated_topic -def valid_subscribe_topic(value: Any) -> str: +def valid_subscribe_topic(topic: Any) -> str: """Validate that we can subscribe using this MQTT topic.""" - value = valid_topic(value) - for i in (i for i, c in enumerate(value) if c == "+"): - if (i > 0 and value[i - 1] != "/") or ( - i < len(value) - 1 and value[i + 1] != "/" + validated_topic = valid_topic(topic) + for i in (i for i, c in enumerate(validated_topic) if c == "+"): + if (i > 0 and validated_topic[i - 1] != "/") or ( + i < len(validated_topic) - 1 and validated_topic[i + 1] != "/" ): raise vol.Invalid( "Single-level wildcard must occupy an entire level of the filter" ) - index = value.find("#") + index = validated_topic.find("#") if index != -1: - if index != len(value) - 1: + if index != len(validated_topic) - 1: # If there are multiple wildcards, this will also trigger raise vol.Invalid( "Multi-level wildcard must be the last " "character in the topic filter." ) - if len(value) > 1 and value[index - 1] != "/": + if len(validated_topic) > 1 and validated_topic[index - 1] != "/": raise vol.Invalid( "Multi-level wildcard must be after a topic level separator." ) - return value + return validated_topic def valid_subscribe_topic_template(value: Any) -> template.Template: @@ -104,12 +104,12 @@ def valid_subscribe_topic_template(value: Any) -> template.Template: return tpl -def valid_publish_topic(value: Any) -> str: +def valid_publish_topic(topic: Any) -> str: """Validate that we can publish using this MQTT topic.""" - value = valid_topic(value) - if "+" in value or "#" in value: + validated_topic = valid_topic(topic) + if "+" in validated_topic or "#" in validated_topic: raise vol.Invalid("Wildcards can not be used in topic names") - return value + return validated_topic _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))