diff --git a/homeassistant/components/modbus/__init__.py b/homeassistant/components/modbus/__init__.py index b9765f5e5ee..d4b70796c1a 100644 --- a/homeassistant/components/modbus/__init__.py +++ b/homeassistant/components/modbus/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -from typing import Any import voluptuous as vol @@ -99,71 +98,20 @@ from .const import ( DEFAULT_SCAN_INTERVAL, DEFAULT_STRUCTURE_PREFIX, DEFAULT_TEMP_UNIT, - MINIMUM_SCAN_INTERVAL, MODBUS_DOMAIN as DOMAIN, - PLATFORMS, ) from .modbus import async_modbus_setup -from .validators import sensor_schema_validator +from .validators import ( + number_validator, + scan_interval_validator, + sensor_schema_validator, +) _LOGGER = logging.getLogger(__name__) BASE_SCHEMA = vol.Schema({vol.Optional(CONF_NAME, default=DEFAULT_HUB): cv.string}) -def number(value: Any) -> int | float: - """Coerce a value to number without losing precision.""" - if isinstance(value, int): - return value - if isinstance(value, float): - return value - - try: - value = int(value) - return value - except (TypeError, ValueError): - pass - try: - value = float(value) - return value - except (TypeError, ValueError) as err: - raise vol.Invalid(f"invalid number {value}") from err - - -def control_scan_interval(config: dict) -> dict: - """Control scan_interval.""" - for hub in config: - minimum_scan_interval = DEFAULT_SCAN_INTERVAL - for component, conf_key in PLATFORMS: - if conf_key not in hub: - continue - - for entry in hub[conf_key]: - scan_interval = entry.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) - if scan_interval < MINIMUM_SCAN_INTERVAL: - if scan_interval == 0: - continue - _LOGGER.warning( - "%s %s scan_interval(%d) is adjusted to minimum(%d)", - component, - entry.get(CONF_NAME), - scan_interval, - MINIMUM_SCAN_INTERVAL, - ) - scan_interval = MINIMUM_SCAN_INTERVAL - entry[CONF_SCAN_INTERVAL] = scan_interval - minimum_scan_interval = min(scan_interval, minimum_scan_interval) - if CONF_TIMEOUT in hub and hub[CONF_TIMEOUT] > minimum_scan_interval - 1: - _LOGGER.warning( - "Modbus %s timeout(%d) is adjusted(%d) due to scan_interval", - hub.get(CONF_NAME, ""), - hub[CONF_TIMEOUT], - minimum_scan_interval - 1, - ) - hub[CONF_TIMEOUT] = minimum_scan_interval - 1 - return config - - BASE_COMPONENT_SCHEMA = vol.Schema( { vol.Required(CONF_NAME): cv.string, @@ -311,7 +259,7 @@ SENSOR_SCHEMA = BASE_COMPONENT_SCHEMA.extend( ] ), vol.Optional(CONF_DEVICE_CLASS): SENSOR_DEVICE_CLASSES_SCHEMA, - vol.Optional(CONF_OFFSET, default=0): number, + vol.Optional(CONF_OFFSET, default=0): number_validator, vol.Optional(CONF_PRECISION, default=0): cv.positive_int, vol.Optional(CONF_INPUT_TYPE, default=CALL_TYPE_REGISTER_HOLDING): vol.In( [CALL_TYPE_REGISTER_HOLDING, CALL_TYPE_REGISTER_INPUT] @@ -320,7 +268,7 @@ SENSOR_SCHEMA = BASE_COMPONENT_SCHEMA.extend( vol.Optional(CONF_SWAP, default=CONF_SWAP_NONE): vol.In( [CONF_SWAP_NONE, CONF_SWAP_BYTE, CONF_SWAP_WORD, CONF_SWAP_WORD_BYTE] ), - vol.Optional(CONF_SCALE, default=1): number, + vol.Optional(CONF_SCALE, default=1): number_validator, vol.Optional(CONF_STRUCTURE): cv.string, vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, } @@ -380,7 +328,7 @@ CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.All( cv.ensure_list, - control_scan_interval, + scan_interval_validator, [ vol.Any(SERIAL_SCHEMA, ETHERNET_SCHEMA), ], diff --git a/homeassistant/components/modbus/validators.py b/homeassistant/components/modbus/validators.py index cd0c4524c74..0f376609de5 100644 --- a/homeassistant/components/modbus/validators.py +++ b/homeassistant/components/modbus/validators.py @@ -1,10 +1,19 @@ """Validate Modbus configuration.""" +from __future__ import annotations + import logging import struct +from typing import Any -from voluptuous import Invalid +import voluptuous as vol -from homeassistant.const import CONF_COUNT, CONF_NAME, CONF_STRUCTURE +from homeassistant.const import ( + CONF_COUNT, + CONF_NAME, + CONF_SCAN_INTERVAL, + CONF_STRUCTURE, + CONF_TIMEOUT, +) from .const import ( CONF_DATA_TYPE, @@ -15,7 +24,10 @@ from .const import ( CONF_SWAP_WORD, DATA_TYPE_CUSTOM, DATA_TYPE_STRING, + DEFAULT_SCAN_INTERVAL, DEFAULT_STRUCT_FORMAT, + MINIMUM_SCAN_INTERVAL, + PLATFORMS, ) _LOGGER = logging.getLogger(__name__) @@ -32,14 +44,14 @@ def sensor_schema_validator(config): f">{DEFAULT_STRUCT_FORMAT[config[CONF_DATA_TYPE]][config[CONF_COUNT]]}" ) except KeyError: - raise Invalid( + raise vol.Invalid( f"Unable to detect data type for {config[CONF_NAME]} sensor, try a custom type" ) from KeyError else: structure = config.get(CONF_STRUCTURE) if not structure: - raise Invalid( + raise vol.Invalid( f"Error in sensor {config[CONF_NAME]}. The `{CONF_STRUCTURE}` field can not be empty " f"if the parameter `{CONF_DATA_TYPE}` is set to the `{DATA_TYPE_CUSTOM}`" ) @@ -47,13 +59,13 @@ def sensor_schema_validator(config): try: size = struct.calcsize(structure) except struct.error as err: - raise Invalid( + raise vol.Invalid( f"Error in sensor {config[CONF_NAME]} structure: {str(err)}" ) from err bytecount = config[CONF_COUNT] * 2 if bytecount != size: - raise Invalid( + raise vol.Invalid( f"Structure request {size} bytes, " f"but {config[CONF_COUNT]} registers have a size of {bytecount} bytes" ) @@ -73,7 +85,7 @@ def sensor_schema_validator(config): else: # CONF_SWAP_WORD_BYTE, CONF_SWAP_WORD regs_needed = 2 if config[CONF_COUNT] < regs_needed or (config[CONF_COUNT] % regs_needed) != 0: - raise Invalid( + raise vol.Invalid( f"Error in sensor {config[CONF_NAME]} swap({swap_type}) " f"not possible due to the registers " f"count: {config[CONF_COUNT]}, needed: {regs_needed}" @@ -84,3 +96,56 @@ def sensor_schema_validator(config): CONF_STRUCTURE: structure, CONF_SWAP: swap_type, } + + +def number_validator(value: Any) -> int | float: + """Coerce a value to number without losing precision.""" + if isinstance(value, int): + return value + if isinstance(value, float): + return value + + try: + value = int(value) + return value + except (TypeError, ValueError): + pass + try: + value = float(value) + return value + except (TypeError, ValueError) as err: + raise vol.Invalid(f"invalid number {value}") from err + + +def scan_interval_validator(config: dict) -> dict: + """Control scan_interval.""" + for hub in config: + minimum_scan_interval = DEFAULT_SCAN_INTERVAL + for component, conf_key in PLATFORMS: + if conf_key not in hub: + continue + + for entry in hub[conf_key]: + scan_interval = entry.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) + if scan_interval < MINIMUM_SCAN_INTERVAL: + if scan_interval == 0: + continue + _LOGGER.warning( + "%s %s scan_interval(%d) is adjusted to minimum(%d)", + component, + entry.get(CONF_NAME), + scan_interval, + MINIMUM_SCAN_INTERVAL, + ) + scan_interval = MINIMUM_SCAN_INTERVAL + entry[CONF_SCAN_INTERVAL] = scan_interval + minimum_scan_interval = min(scan_interval, minimum_scan_interval) + if CONF_TIMEOUT in hub and hub[CONF_TIMEOUT] > minimum_scan_interval - 1: + _LOGGER.warning( + "Modbus %s timeout(%d) is adjusted(%d) due to scan_interval", + hub.get(CONF_NAME, ""), + hub[CONF_TIMEOUT], + minimum_scan_interval - 1, + ) + hub[CONF_TIMEOUT] = minimum_scan_interval - 1 + return config diff --git a/tests/components/modbus/test_init.py b/tests/components/modbus/test_init.py index 0819e5a3e89..8e45ee06976 100644 --- a/tests/components/modbus/test_init.py +++ b/tests/components/modbus/test_init.py @@ -19,7 +19,6 @@ import pytest import voluptuous as vol from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN -from homeassistant.components.modbus import number from homeassistant.components.modbus.const import ( ATTR_ADDRESS, ATTR_HUB, @@ -44,6 +43,7 @@ from homeassistant.components.modbus.const import ( SERVICE_WRITE_COIL, SERVICE_WRITE_REGISTER, ) +from homeassistant.components.modbus.validators import number_validator from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.const import ( CONF_ADDRESS, @@ -85,13 +85,13 @@ async def test_number_validator(): ("-15", int), ("-15.1", float), ]: - assert isinstance(number(value), value_type) + assert isinstance(number_validator(value), value_type) try: - number("x15.1") + number_validator("x15.1") except (vol.Invalid): return - pytest.fail("Number not throwing exception") + pytest.fail("Number_validator not throwing exception") @pytest.mark.parametrize(