From 51c8b1eb0b2d0e07610bc753883ae2873fa5857e Mon Sep 17 00:00:00 2001 From: Michael <35783820+mib1185@users.noreply.github.com> Date: Mon, 24 May 2021 12:03:43 +0200 Subject: [PATCH] Improve code quality of TCP platform (#51000) * Improve code placements * Fix entity inheritance * fix tests * Improve PLATFORM_SCHEMA handling * Apply suggestions --- homeassistant/components/tcp/binary_sensor.py | 11 +- homeassistant/components/tcp/common.py | 158 ++++++++++++++++++ homeassistant/components/tcp/sensor.py | 150 +---------------- tests/components/tcp/test_binary_sensor.py | 4 +- tests/components/tcp/test_sensor.py | 8 +- 5 files changed, 175 insertions(+), 156 deletions(-) create mode 100644 homeassistant/components/tcp/common.py diff --git a/homeassistant/components/tcp/binary_sensor.py b/homeassistant/components/tcp/binary_sensor.py index c0e53fba334..ee26ff74a7f 100644 --- a/homeassistant/components/tcp/binary_sensor.py +++ b/homeassistant/components/tcp/binary_sensor.py @@ -3,15 +3,18 @@ from __future__ import annotations from typing import Any, Final -from homeassistant.components.binary_sensor import BinarySensorEntity +from homeassistant.components.binary_sensor import ( + PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA, + BinarySensorEntity, +) from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType +from .common import TCP_PLATFORM_SCHEMA, TcpEntity from .const import CONF_VALUE_ON -from .sensor import PLATFORM_SCHEMA as TCP_PLATFORM_SCHEMA, TcpSensor -PLATFORM_SCHEMA: Final = TCP_PLATFORM_SCHEMA +PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend(TCP_PLATFORM_SCHEMA) def setup_platform( @@ -24,7 +27,7 @@ def setup_platform( add_entities([TcpBinarySensor(hass, config)]) -class TcpBinarySensor(BinarySensorEntity, TcpSensor): +class TcpBinarySensor(TcpEntity, BinarySensorEntity): """A binary sensor which is on when its state == CONF_VALUE_ON.""" @property diff --git a/homeassistant/components/tcp/common.py b/homeassistant/components/tcp/common.py new file mode 100644 index 00000000000..d2d40358970 --- /dev/null +++ b/homeassistant/components/tcp/common.py @@ -0,0 +1,158 @@ +"""Common code for TCP component.""" +from __future__ import annotations + +import logging +import select +import socket +import ssl +from typing import Any, Final + +import voluptuous as vol + +from homeassistant.const import ( + CONF_HOST, + CONF_NAME, + CONF_PAYLOAD, + CONF_PORT, + CONF_SSL, + CONF_TIMEOUT, + CONF_UNIT_OF_MEASUREMENT, + CONF_VALUE_TEMPLATE, + CONF_VERIFY_SSL, +) +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import TemplateError +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.template import Template +from homeassistant.helpers.typing import ConfigType + +from .const import ( + CONF_BUFFER_SIZE, + CONF_VALUE_ON, + DEFAULT_BUFFER_SIZE, + DEFAULT_NAME, + DEFAULT_SSL, + DEFAULT_TIMEOUT, + DEFAULT_VERIFY_SSL, +) +from .model import TcpSensorConfig + +_LOGGER: Final = logging.getLogger(__name__) + + +TCP_PLATFORM_SCHEMA: Final[dict[vol.Marker, Any]] = { + vol.Required(CONF_HOST): cv.string, + vol.Required(CONF_PORT): cv.port, + vol.Required(CONF_PAYLOAD): cv.string, + vol.Optional(CONF_BUFFER_SIZE, default=DEFAULT_BUFFER_SIZE): cv.positive_int, + vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, + vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, + vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, + vol.Optional(CONF_VALUE_ON): cv.string, + vol.Optional(CONF_VALUE_TEMPLATE): cv.template, + vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, + vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, +} + + +class TcpEntity(Entity): + """Base entity class for TCP platform.""" + + def __init__(self, hass: HomeAssistant, config: ConfigType) -> None: + """Set all the config values if they exist and get initial state.""" + + value_template: Template | None = config.get(CONF_VALUE_TEMPLATE) + if value_template is not None: + value_template.hass = hass + + self._hass = hass + self._config: TcpSensorConfig = { + CONF_NAME: config[CONF_NAME], + CONF_HOST: config[CONF_HOST], + CONF_PORT: config[CONF_PORT], + CONF_TIMEOUT: config[CONF_TIMEOUT], + CONF_PAYLOAD: config[CONF_PAYLOAD], + CONF_UNIT_OF_MEASUREMENT: config.get(CONF_UNIT_OF_MEASUREMENT), + CONF_VALUE_TEMPLATE: value_template, + CONF_VALUE_ON: config.get(CONF_VALUE_ON), + CONF_BUFFER_SIZE: config[CONF_BUFFER_SIZE], + CONF_SSL: config[CONF_SSL], + CONF_VERIFY_SSL: config[CONF_VERIFY_SSL], + } + + self._ssl_context: ssl.SSLContext | None = None + if self._config[CONF_SSL]: + self._ssl_context = ssl.create_default_context() + if not self._config[CONF_VERIFY_SSL]: + self._ssl_context.check_hostname = False + self._ssl_context.verify_mode = ssl.CERT_NONE + + self._state: str | None = None + self.update() + + @property + def name(self) -> str: + """Return the name of this sensor.""" + return self._config[CONF_NAME] + + def update(self) -> None: + """Get the latest value for this sensor.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(self._config[CONF_TIMEOUT]) + try: + sock.connect((self._config[CONF_HOST], self._config[CONF_PORT])) + except OSError as err: + _LOGGER.error( + "Unable to connect to %s on port %s: %s", + self._config[CONF_HOST], + self._config[CONF_PORT], + err, + ) + return + + if self._ssl_context is not None: + sock = self._ssl_context.wrap_socket( + sock, server_hostname=self._config[CONF_HOST] + ) + + try: + sock.send(self._config[CONF_PAYLOAD].encode()) + except OSError as err: + _LOGGER.error( + "Unable to send payload %r to %s on port %s: %s", + self._config[CONF_PAYLOAD], + self._config[CONF_HOST], + self._config[CONF_PORT], + err, + ) + return + + readable, _, _ = select.select([sock], [], [], self._config[CONF_TIMEOUT]) + if not readable: + _LOGGER.warning( + "Timeout (%s second(s)) waiting for a response after " + "sending %r to %s on port %s", + self._config[CONF_TIMEOUT], + self._config[CONF_PAYLOAD], + self._config[CONF_HOST], + self._config[CONF_PORT], + ) + return + + value = sock.recv(self._config[CONF_BUFFER_SIZE]).decode() + + value_template = self._config[CONF_VALUE_TEMPLATE] + if value_template is not None: + try: + self._state = value_template.render(parse_result=False, value=value) + return + except TemplateError: + _LOGGER.error( + "Unable to render template of %r with value: %r", + self._config[CONF_VALUE_TEMPLATE], + value, + ) + return + + self._state = value diff --git a/homeassistant/components/tcp/sensor.py b/homeassistant/components/tcp/sensor.py index ff5db39bba7..d282974fd4c 100644 --- a/homeassistant/components/tcp/sensor.py +++ b/homeassistant/components/tcp/sensor.py @@ -1,64 +1,20 @@ """Support for TCP socket based sensors.""" from __future__ import annotations -import logging -import select -import socket -import ssl from typing import Any, Final -import voluptuous as vol - from homeassistant.components.sensor import ( PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA, SensorEntity, ) -from homeassistant.const import ( - CONF_HOST, - CONF_NAME, - CONF_PAYLOAD, - CONF_PORT, - CONF_SSL, - CONF_TIMEOUT, - CONF_UNIT_OF_MEASUREMENT, - CONF_VALUE_TEMPLATE, - CONF_VERIFY_SSL, -) +from homeassistant.const import CONF_UNIT_OF_MEASUREMENT from homeassistant.core import HomeAssistant -from homeassistant.exceptions import TemplateError -import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, StateType -from .const import ( - CONF_BUFFER_SIZE, - CONF_VALUE_ON, - DEFAULT_BUFFER_SIZE, - DEFAULT_NAME, - DEFAULT_SSL, - DEFAULT_TIMEOUT, - DEFAULT_VERIFY_SSL, -) -from .model import TcpSensorConfig +from .common import TCP_PLATFORM_SCHEMA, TcpEntity -_LOGGER: Final = logging.getLogger(__name__) - -PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_HOST): cv.string, - vol.Required(CONF_PORT): cv.port, - vol.Required(CONF_PAYLOAD): cv.string, - vol.Optional(CONF_BUFFER_SIZE, default=DEFAULT_BUFFER_SIZE): cv.positive_int, - vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, - vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, - vol.Optional(CONF_VALUE_ON): cv.string, - vol.Optional(CONF_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, - vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, - } -) +PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend(TCP_PLATFORM_SCHEMA) def setup_platform( @@ -71,46 +27,9 @@ def setup_platform( add_entities([TcpSensor(hass, config)]) -class TcpSensor(SensorEntity): +class TcpSensor(TcpEntity, SensorEntity): """Implementation of a TCP socket based sensor.""" - def __init__(self, hass: HomeAssistant, config: ConfigType) -> None: - """Set all the config values if they exist and get initial state.""" - - value_template: Template | None = config.get(CONF_VALUE_TEMPLATE) - if value_template is not None: - value_template.hass = hass - - self._hass = hass - self._config: TcpSensorConfig = { - CONF_NAME: config[CONF_NAME], - CONF_HOST: config[CONF_HOST], - CONF_PORT: config[CONF_PORT], - CONF_TIMEOUT: config[CONF_TIMEOUT], - CONF_PAYLOAD: config[CONF_PAYLOAD], - CONF_UNIT_OF_MEASUREMENT: config.get(CONF_UNIT_OF_MEASUREMENT), - CONF_VALUE_TEMPLATE: value_template, - CONF_VALUE_ON: config.get(CONF_VALUE_ON), - CONF_BUFFER_SIZE: config[CONF_BUFFER_SIZE], - CONF_SSL: config[CONF_SSL], - CONF_VERIFY_SSL: config[CONF_VERIFY_SSL], - } - - self._ssl_context: ssl.SSLContext | None = None - if self._config[CONF_SSL]: - self._ssl_context = ssl.create_default_context() - if not self._config[CONF_VERIFY_SSL]: - self._ssl_context.check_hostname = False - self._ssl_context.verify_mode = ssl.CERT_NONE - - self._state: str | None = None - self.update() - - @property - def name(self) -> str: - """Return the name of this sensor.""" - return self._config[CONF_NAME] - @property def state(self) -> StateType: """Return the state of the device.""" @@ -120,64 +39,3 @@ class TcpSensor(SensorEntity): def unit_of_measurement(self) -> str | None: """Return the unit of measurement of this entity.""" return self._config[CONF_UNIT_OF_MEASUREMENT] - - def update(self) -> None: - """Get the latest value for this sensor.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.settimeout(self._config[CONF_TIMEOUT]) - try: - sock.connect((self._config[CONF_HOST], self._config[CONF_PORT])) - except OSError as err: - _LOGGER.error( - "Unable to connect to %s on port %s: %s", - self._config[CONF_HOST], - self._config[CONF_PORT], - err, - ) - return - - if self._ssl_context is not None: - sock = self._ssl_context.wrap_socket( - sock, server_hostname=self._config[CONF_HOST] - ) - - try: - sock.send(self._config[CONF_PAYLOAD].encode()) - except OSError as err: - _LOGGER.error( - "Unable to send payload %r to %s on port %s: %s", - self._config[CONF_PAYLOAD], - self._config[CONF_HOST], - self._config[CONF_PORT], - err, - ) - return - - readable, _, _ = select.select([sock], [], [], self._config[CONF_TIMEOUT]) - if not readable: - _LOGGER.warning( - "Timeout (%s second(s)) waiting for a response after " - "sending %r to %s on port %s", - self._config[CONF_TIMEOUT], - self._config[CONF_PAYLOAD], - self._config[CONF_HOST], - self._config[CONF_PORT], - ) - return - - value = sock.recv(self._config[CONF_BUFFER_SIZE]).decode() - - value_template = self._config[CONF_VALUE_TEMPLATE] - if value_template is not None: - try: - self._state = value_template.render(parse_result=False, value=value) - return - except TemplateError: - _LOGGER.error( - "Unable to render template of %r with value: %r", - self._config[CONF_VALUE_TEMPLATE], - value, - ) - return - - self._state = value diff --git a/tests/components/tcp/test_binary_sensor.py b/tests/components/tcp/test_binary_sensor.py index 21dd84b1892..f8c13b41c30 100644 --- a/tests/components/tcp/test_binary_sensor.py +++ b/tests/components/tcp/test_binary_sensor.py @@ -20,9 +20,9 @@ TEST_ENTITY = "binary_sensor.test_name" def mock_socket_fixture(): """Mock the socket.""" with patch( - "homeassistant.components.tcp.sensor.socket.socket" + "homeassistant.components.tcp.common.socket.socket" ) as mock_socket, patch( - "homeassistant.components.tcp.sensor.select.select", + "homeassistant.components.tcp.common.select.select", return_value=(True, False, False), ): # yield the return value of the socket context manager diff --git a/tests/components/tcp/test_sensor.py b/tests/components/tcp/test_sensor.py index 48b5703c204..46db3367677 100644 --- a/tests/components/tcp/test_sensor.py +++ b/tests/components/tcp/test_sensor.py @@ -4,7 +4,7 @@ from unittest.mock import call, patch import pytest -import homeassistant.components.tcp.sensor as tcp +import homeassistant.components.tcp.common as tcp from homeassistant.setup import async_setup_component from tests.common import assert_setup_component @@ -41,7 +41,7 @@ socket_test_value = "value" @pytest.fixture(name="mock_socket") def mock_socket_fixture(mock_select): """Mock socket.""" - with patch("homeassistant.components.tcp.sensor.socket.socket") as mock_socket: + with patch("homeassistant.components.tcp.common.socket.socket") as mock_socket: socket_instance = mock_socket.return_value.__enter__.return_value socket_instance.recv.return_value = socket_test_value.encode() yield socket_instance @@ -51,7 +51,7 @@ def mock_socket_fixture(mock_select): def mock_select_fixture(): """Mock select.""" with patch( - "homeassistant.components.tcp.sensor.select.select", + "homeassistant.components.tcp.common.select.select", return_value=(True, False, False), ) as mock_select: yield mock_select @@ -61,7 +61,7 @@ def mock_select_fixture(): def mock_ssl_context_fixture(): """Mock select.""" with patch( - "homeassistant.components.tcp.sensor.ssl.create_default_context", + "homeassistant.components.tcp.common.ssl.create_default_context", ) as mock_ssl_context: mock_ssl_context.return_value.wrap_socket.return_value.recv.return_value = ( socket_test_value + "_ssl"