Fix mysensors typing (#51518)

* Fix device

* Fix init

* Fix gateway

* Fix config flow

* Fix helpers

* Remove mysensors from typing ignore list
This commit is contained in:
Martin Hjelmare 2021-06-05 13:43:39 +02:00 committed by GitHub
parent 7a6d067eb4
commit e73cdfab2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 80 additions and 76 deletions

View file

@ -42,7 +42,7 @@ from .const import (
DevId, DevId,
SensorType, SensorType,
) )
from .device import MySensorsDevice, MySensorsEntity, get_mysensors_devices from .device import MySensorsDevice, get_mysensors_devices
from .gateway import finish_setup, get_mysensors_gateway, gw_stop, setup_gateway from .gateway import finish_setup, get_mysensors_gateway, gw_stop, setup_gateway
from .helpers import on_unload from .helpers import on_unload
@ -271,7 +271,7 @@ def setup_mysensors_platform(
hass: HomeAssistant, hass: HomeAssistant,
domain: str, # hass platform name domain: str, # hass platform name
discovery_info: dict[str, list[DevId]], discovery_info: dict[str, list[DevId]],
device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsEntity]], device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsDevice]],
device_args: ( device_args: (
None | tuple None | tuple
) = None, # extra arguments that will be given to the entity constructor ) = None, # extra arguments that will be given to the entity constructor
@ -302,11 +302,13 @@ def setup_mysensors_platform(
if not gateway: if not gateway:
_LOGGER.warning("Skipping setup of %s, no gateway found", dev_id) _LOGGER.warning("Skipping setup of %s, no gateway found", dev_id)
continue continue
device_class_copy = device_class
if isinstance(device_class, dict): if isinstance(device_class, dict):
child = gateway.sensors[node_id].children[child_id] child = gateway.sensors[node_id].children[child_id]
s_type = gateway.const.Presentation(child.type).name s_type = gateway.const.Presentation(child.type).name
device_class_copy = device_class[s_type] device_class_copy = device_class[s_type]
else:
device_class_copy = device_class
args_copy = (*device_args, gateway_id, gateway, node_id, child_id, value_type) args_copy = (*device_args, gateway_id, gateway, node_id, child_id, value_type)
devices[dev_id] = device_class_copy(*args_copy) devices[dev_id] = device_class_copy(*args_copy)

View file

@ -27,7 +27,7 @@ from homeassistant.components.mysensors import (
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.data_entry_flow import RESULT_TYPE_FORM, FlowResult from homeassistant.data_entry_flow import FlowResult
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
@ -111,7 +111,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Set up config flow.""" """Set up config flow."""
self._gw_type: str | None = None self._gw_type: str | None = None
async def async_step_import(self, user_input: dict[str, str] | None = None): async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
"""Import a config entry. """Import a config entry.
This method is called by async_setup and it has already This method is called by async_setup and it has already
@ -131,12 +131,14 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
else: else:
user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL
result: dict[str, Any] = await self.async_step_user(user_input=user_input) result: FlowResult = await self.async_step_user(user_input=user_input)
if result["type"] == RESULT_TYPE_FORM: if errors := result.get("errors"):
return self.async_abort(reason=next(iter(result["errors"].values()))) return self.async_abort(reason=next(iter(errors.values())))
return result return result
async def async_step_user(self, user_input: dict[str, str] | None = None): async def async_step_user(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Create a config entry from frontend user input.""" """Create a config entry from frontend user input."""
schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)} schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)}
schema = vol.Schema(schema) schema = vol.Schema(schema)
@ -158,9 +160,11 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form(step_id="user", data_schema=schema, errors=errors) return self.async_show_form(step_id="user", data_schema=schema, errors=errors)
async def async_step_gw_serial(self, user_input: dict[str, str] | None = None): async def async_step_gw_serial(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create config entry for a serial gateway.""" """Create config entry for a serial gateway."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
errors.update( errors.update(
await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input) await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input)
@ -187,7 +191,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
step_id="gw_serial", data_schema=schema, errors=errors step_id="gw_serial", data_schema=schema, errors=errors
) )
async def async_step_gw_tcp(self, user_input: dict[str, str] | None = None): async def async_step_gw_tcp(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create a config entry for a tcp gateway.""" """Create a config entry for a tcp gateway."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
@ -225,7 +231,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return True return True
return False return False
async def async_step_gw_mqtt(self, user_input: dict[str, str] | None = None): async def async_step_gw_mqtt(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create a config entry for a mqtt gateway.""" """Create a config entry for a mqtt gateway."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
@ -280,9 +288,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
) )
@callback @callback
def _async_create_entry( def _async_create_entry(self, user_input: dict[str, Any]) -> FlowResult:
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Create the config entry.""" """Create the config entry."""
return self.async_create_entry( return self.async_create_entry(
title=f"{user_input[CONF_DEVICE]}", title=f"{user_input[CONF_DEVICE]}",
@ -296,55 +302,52 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self, self,
gw_type: ConfGatewayType, gw_type: ConfGatewayType,
errors: dict[str, str], errors: dict[str, str],
user_input: dict[str, str] | None = None, user_input: dict[str, Any],
) -> dict[str, str]: ) -> dict[str, str]:
"""Validate parameters common to all gateway types.""" """Validate parameters common to all gateway types."""
if user_input is not None: errors.update(_validate_version(user_input[CONF_VERSION]))
errors.update(_validate_version(user_input.get(CONF_VERSION)))
if gw_type != CONF_GATEWAY_TYPE_MQTT: if gw_type != CONF_GATEWAY_TYPE_MQTT:
if gw_type == CONF_GATEWAY_TYPE_TCP: if gw_type == CONF_GATEWAY_TYPE_TCP:
verification_func = is_socket_address verification_func = is_socket_address
else: else:
verification_func = is_serial_port verification_func = is_serial_port
try: try:
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
verification_func, user_input.get(CONF_DEVICE) verification_func, user_input.get(CONF_DEVICE)
) )
except vol.Invalid: except vol.Invalid:
errors[CONF_DEVICE] = ( errors[CONF_DEVICE] = (
"invalid_ip" "invalid_ip"
if gw_type == CONF_GATEWAY_TYPE_TCP if gw_type == CONF_GATEWAY_TYPE_TCP
else "invalid_serial" else "invalid_serial"
) )
if CONF_PERSISTENCE_FILE in user_input: if CONF_PERSISTENCE_FILE in user_input:
try: try:
is_persistence_file(user_input[CONF_PERSISTENCE_FILE]) is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
except vol.Invalid: except vol.Invalid:
errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file" errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
else: else:
real_persistence_path = user_input[ real_persistence_path = user_input[
CONF_PERSISTENCE_FILE CONF_PERSISTENCE_FILE
] = self._normalize_persistence_file( ] = self._normalize_persistence_file(user_input[CONF_PERSISTENCE_FILE])
user_input[CONF_PERSISTENCE_FILE] for other_entry in self._async_current_entries():
) if CONF_PERSISTENCE_FILE not in other_entry.data:
for other_entry in self._async_current_entries(): continue
if CONF_PERSISTENCE_FILE not in other_entry.data: if real_persistence_path == self._normalize_persistence_file(
continue other_entry.data[CONF_PERSISTENCE_FILE]
if real_persistence_path == self._normalize_persistence_file( ):
other_entry.data[CONF_PERSISTENCE_FILE] errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
): break
errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
break
for other_entry in self._async_current_entries(): for other_entry in self._async_current_entries():
if _is_same_device(gw_type, user_input, other_entry): if _is_same_device(gw_type, user_input, other_entry):
errors["base"] = "already_configured" errors["base"] = "already_configured"
break break
# if no errors so far, try to connect # if no errors so far, try to connect
if not errors and not await try_connect(self.hass, user_input): if not errors and not await try_connect(self.hass, user_input):
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
return errors return errors

View file

@ -3,12 +3,13 @@ from __future__ import annotations
from functools import partial from functools import partial
import logging import logging
from typing import Any
from mysensors import BaseAsyncGateway, Sensor from mysensors import BaseAsyncGateway, Sensor
from mysensors.sensor import ChildSensor from mysensors.sensor import ChildSensor
from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_OFF, STATE_ON from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_OFF, STATE_ON
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo, Entity from homeassistant.helpers.entity import DeviceInfo, Entity
@ -36,6 +37,8 @@ MYSENSORS_PLATFORM_DEVICES = "mysensors_devices_{}"
class MySensorsDevice: class MySensorsDevice:
"""Representation of a MySensors device.""" """Representation of a MySensors device."""
hass: HomeAssistant
def __init__( def __init__(
self, self,
gateway_id: GatewayId, gateway_id: GatewayId,
@ -51,9 +54,8 @@ class MySensorsDevice:
self.child_id: int = child_id self.child_id: int = child_id
self.value_type: int = value_type # value_type as int. string variant can be looked up in gateway consts self.value_type: int = value_type # value_type as int. string variant can be looked up in gateway consts
self.child_type = self._child.type self.child_type = self._child.type
self._values = {} self._values: dict[int, Any] = {}
self._update_scheduled = False self._update_scheduled = False
self.hass = None
@property @property
def dev_id(self) -> DevId: def dev_id(self) -> DevId:

View file

@ -66,7 +66,7 @@ def is_socket_address(value):
raise vol.Invalid("Device is not a valid domain name or ip address") from err raise vol.Invalid("Device is not a valid domain name or ip address") from err
async def try_connect(hass: HomeAssistant, user_input: dict[str, str]) -> bool: async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool:
"""Try to connect to a gateway and report if it worked.""" """Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT: if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :( return True # dont validate mqtt. mqtt gateways dont send ready messages :(
@ -250,7 +250,6 @@ async def _discover_persistent_devices(
hass: HomeAssistant, entry: ConfigEntry, gateway: BaseAsyncGateway hass: HomeAssistant, entry: ConfigEntry, gateway: BaseAsyncGateway
): ):
"""Discover platforms for devices loaded via persistence file.""" """Discover platforms for devices loaded via persistence file."""
tasks = []
new_devices = defaultdict(list) new_devices = defaultdict(list)
for node_id in gateway.sensors: for node_id in gateway.sensors:
if not validate_node(gateway, node_id): if not validate_node(gateway, node_id):
@ -263,8 +262,6 @@ async def _discover_persistent_devices(
_LOGGER.debug("discovering persistent devices: %s", new_devices) _LOGGER.debug("discovering persistent devices: %s", new_devices)
for platform, dev_ids in new_devices.items(): for platform, dev_ids in new_devices.items():
discover_mysensors_platform(hass, entry.entry_id, platform, dev_ids) discover_mysensors_platform(hass, entry.entry_id, platform, dev_ids)
if tasks:
await asyncio.wait(tasks)
async def gw_stop(hass, entry: ConfigEntry, gateway: BaseAsyncGateway): async def gw_stop(hass, entry: ConfigEntry, gateway: BaseAsyncGateway):
@ -331,8 +328,8 @@ def _gw_callback_factory(
msg_type = msg.gateway.const.MessageType(msg.type) msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[ msg_handler: Callable[
[Any, GatewayId, Message], Coroutine[None] [HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
] = HANDLERS.get(msg_type.name) ] | None = HANDLERS.get(msg_type.name)
if msg_handler is None: if msg_handler is None:
return return

View file

@ -176,11 +176,15 @@ def validate_child(
) -> defaultdict[str, list[DevId]]: ) -> defaultdict[str, list[DevId]]:
"""Validate a child. Returns a dict mapping hass platform names to list of DevId.""" """Validate a child. Returns a dict mapping hass platform names to list of DevId."""
validated: defaultdict[str, list[DevId]] = defaultdict(list) validated: defaultdict[str, list[DevId]] = defaultdict(list)
pres: IntEnum = gateway.const.Presentation pres: type[IntEnum] = gateway.const.Presentation
set_req: IntEnum = gateway.const.SetReq set_req: type[IntEnum] = gateway.const.SetReq
child_type_name: SensorType | None = next( child_type_name: SensorType | None = next(
(member.name for member in pres if member.value == child.type), None (member.name for member in pres if member.value == child.type), None
) )
if not child_type_name:
_LOGGER.warning("Child type %s is not supported", child.type)
return validated
value_types: set[int] = {value_type} if value_type else {*child.values} value_types: set[int] = {value_type} if value_type else {*child.values}
value_type_names: set[ValueType] = { value_type_names: set[ValueType] = {
member.name for member in set_req if member.value in value_types member.name for member in set_req if member.value in value_types
@ -199,7 +203,7 @@ def validate_child(
child_value_names: set[ValueType] = { child_value_names: set[ValueType] = {
member.name for member in set_req if member.value in child.values member.name for member in set_req if member.value in child.values
} }
v_names: set[ValueType] = platform_v_names & child_value_names v_names = platform_v_names & child_value_names
for v_name in v_names: for v_name in v_names:
child_schema_gen = SCHEMAS.get((platform, v_name), default_schema) child_schema_gen = SCHEMAS.get((platform, v_name), default_schema)

View file

@ -1197,9 +1197,6 @@ ignore_errors = true
[mypy-homeassistant.components.mullvad.*] [mypy-homeassistant.components.mullvad.*]
ignore_errors = true ignore_errors = true
[mypy-homeassistant.components.mysensors.*]
ignore_errors = true
[mypy-homeassistant.components.neato.*] [mypy-homeassistant.components.neato.*]
ignore_errors = true ignore_errors = true

View file

@ -127,7 +127,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.motion_blinds.*", "homeassistant.components.motion_blinds.*",
"homeassistant.components.mqtt.*", "homeassistant.components.mqtt.*",
"homeassistant.components.mullvad.*", "homeassistant.components.mullvad.*",
"homeassistant.components.mysensors.*",
"homeassistant.components.neato.*", "homeassistant.components.neato.*",
"homeassistant.components.ness_alarm.*", "homeassistant.components.ness_alarm.*",
"homeassistant.components.nest.*", "homeassistant.components.nest.*",