Improve code quality of TCP platform (#51000)

* Improve code placements

* Fix entity inheritance

* fix tests

* Improve PLATFORM_SCHEMA handling

* Apply suggestions
This commit is contained in:
Michael 2021-05-24 12:03:43 +02:00 committed by GitHub
parent 870c61a622
commit 51c8b1eb0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 175 additions and 156 deletions

View file

@ -3,15 +3,18 @@ from __future__ import annotations
from typing import Any, Final
from homeassistant.components.binary_sensor import BinarySensorEntity
from homeassistant.components.binary_sensor import (
PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA,
BinarySensorEntity,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType
from .common import TCP_PLATFORM_SCHEMA, TcpEntity
from .const import CONF_VALUE_ON
from .sensor import PLATFORM_SCHEMA as TCP_PLATFORM_SCHEMA, TcpSensor
PLATFORM_SCHEMA: Final = TCP_PLATFORM_SCHEMA
PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend(TCP_PLATFORM_SCHEMA)
def setup_platform(
@ -24,7 +27,7 @@ def setup_platform(
add_entities([TcpBinarySensor(hass, config)])
class TcpBinarySensor(BinarySensorEntity, TcpSensor):
class TcpBinarySensor(TcpEntity, BinarySensorEntity):
"""A binary sensor which is on when its state == CONF_VALUE_ON."""
@property

View file

@ -0,0 +1,158 @@
"""Common code for TCP component."""
from __future__ import annotations
import logging
import select
import socket
import ssl
from typing import Any, Final
import voluptuous as vol
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
CONF_PAYLOAD,
CONF_PORT,
CONF_SSL,
CONF_TIMEOUT,
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
CONF_VERIFY_SSL,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_BUFFER_SIZE,
CONF_VALUE_ON,
DEFAULT_BUFFER_SIZE,
DEFAULT_NAME,
DEFAULT_SSL,
DEFAULT_TIMEOUT,
DEFAULT_VERIFY_SSL,
)
from .model import TcpSensorConfig
_LOGGER: Final = logging.getLogger(__name__)
TCP_PLATFORM_SCHEMA: Final[dict[vol.Marker, Any]] = {
vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_PORT): cv.port,
vol.Required(CONF_PAYLOAD): cv.string,
vol.Optional(CONF_BUFFER_SIZE, default=DEFAULT_BUFFER_SIZE): cv.positive_int,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int,
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
vol.Optional(CONF_VALUE_ON): cv.string,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
}
class TcpEntity(Entity):
"""Base entity class for TCP platform."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Set all the config values if they exist and get initial state."""
value_template: Template | None = config.get(CONF_VALUE_TEMPLATE)
if value_template is not None:
value_template.hass = hass
self._hass = hass
self._config: TcpSensorConfig = {
CONF_NAME: config[CONF_NAME],
CONF_HOST: config[CONF_HOST],
CONF_PORT: config[CONF_PORT],
CONF_TIMEOUT: config[CONF_TIMEOUT],
CONF_PAYLOAD: config[CONF_PAYLOAD],
CONF_UNIT_OF_MEASUREMENT: config.get(CONF_UNIT_OF_MEASUREMENT),
CONF_VALUE_TEMPLATE: value_template,
CONF_VALUE_ON: config.get(CONF_VALUE_ON),
CONF_BUFFER_SIZE: config[CONF_BUFFER_SIZE],
CONF_SSL: config[CONF_SSL],
CONF_VERIFY_SSL: config[CONF_VERIFY_SSL],
}
self._ssl_context: ssl.SSLContext | None = None
if self._config[CONF_SSL]:
self._ssl_context = ssl.create_default_context()
if not self._config[CONF_VERIFY_SSL]:
self._ssl_context.check_hostname = False
self._ssl_context.verify_mode = ssl.CERT_NONE
self._state: str | None = None
self.update()
@property
def name(self) -> str:
"""Return the name of this sensor."""
return self._config[CONF_NAME]
def update(self) -> None:
"""Get the latest value for this sensor."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(self._config[CONF_TIMEOUT])
try:
sock.connect((self._config[CONF_HOST], self._config[CONF_PORT]))
except OSError as err:
_LOGGER.error(
"Unable to connect to %s on port %s: %s",
self._config[CONF_HOST],
self._config[CONF_PORT],
err,
)
return
if self._ssl_context is not None:
sock = self._ssl_context.wrap_socket(
sock, server_hostname=self._config[CONF_HOST]
)
try:
sock.send(self._config[CONF_PAYLOAD].encode())
except OSError as err:
_LOGGER.error(
"Unable to send payload %r to %s on port %s: %s",
self._config[CONF_PAYLOAD],
self._config[CONF_HOST],
self._config[CONF_PORT],
err,
)
return
readable, _, _ = select.select([sock], [], [], self._config[CONF_TIMEOUT])
if not readable:
_LOGGER.warning(
"Timeout (%s second(s)) waiting for a response after "
"sending %r to %s on port %s",
self._config[CONF_TIMEOUT],
self._config[CONF_PAYLOAD],
self._config[CONF_HOST],
self._config[CONF_PORT],
)
return
value = sock.recv(self._config[CONF_BUFFER_SIZE]).decode()
value_template = self._config[CONF_VALUE_TEMPLATE]
if value_template is not None:
try:
self._state = value_template.render(parse_result=False, value=value)
return
except TemplateError:
_LOGGER.error(
"Unable to render template of %r with value: %r",
self._config[CONF_VALUE_TEMPLATE],
value,
)
return
self._state = value

View file

@ -1,64 +1,20 @@
"""Support for TCP socket based sensors."""
from __future__ import annotations
import logging
import select
import socket
import ssl
from typing import Any, Final
import voluptuous as vol
from homeassistant.components.sensor import (
PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA,
SensorEntity,
)
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
CONF_PAYLOAD,
CONF_PORT,
CONF_SSL,
CONF_TIMEOUT,
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
CONF_VERIFY_SSL,
)
from homeassistant.const import CONF_UNIT_OF_MEASUREMENT
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, StateType
from .const import (
CONF_BUFFER_SIZE,
CONF_VALUE_ON,
DEFAULT_BUFFER_SIZE,
DEFAULT_NAME,
DEFAULT_SSL,
DEFAULT_TIMEOUT,
DEFAULT_VERIFY_SSL,
)
from .model import TcpSensorConfig
from .common import TCP_PLATFORM_SCHEMA, TcpEntity
_LOGGER: Final = logging.getLogger(__name__)
PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_PORT): cv.port,
vol.Required(CONF_PAYLOAD): cv.string,
vol.Optional(CONF_BUFFER_SIZE, default=DEFAULT_BUFFER_SIZE): cv.positive_int,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int,
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
vol.Optional(CONF_VALUE_ON): cv.string,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
}
)
PLATFORM_SCHEMA: Final = PARENT_PLATFORM_SCHEMA.extend(TCP_PLATFORM_SCHEMA)
def setup_platform(
@ -71,46 +27,9 @@ def setup_platform(
add_entities([TcpSensor(hass, config)])
class TcpSensor(SensorEntity):
class TcpSensor(TcpEntity, SensorEntity):
"""Implementation of a TCP socket based sensor."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Set all the config values if they exist and get initial state."""
value_template: Template | None = config.get(CONF_VALUE_TEMPLATE)
if value_template is not None:
value_template.hass = hass
self._hass = hass
self._config: TcpSensorConfig = {
CONF_NAME: config[CONF_NAME],
CONF_HOST: config[CONF_HOST],
CONF_PORT: config[CONF_PORT],
CONF_TIMEOUT: config[CONF_TIMEOUT],
CONF_PAYLOAD: config[CONF_PAYLOAD],
CONF_UNIT_OF_MEASUREMENT: config.get(CONF_UNIT_OF_MEASUREMENT),
CONF_VALUE_TEMPLATE: value_template,
CONF_VALUE_ON: config.get(CONF_VALUE_ON),
CONF_BUFFER_SIZE: config[CONF_BUFFER_SIZE],
CONF_SSL: config[CONF_SSL],
CONF_VERIFY_SSL: config[CONF_VERIFY_SSL],
}
self._ssl_context: ssl.SSLContext | None = None
if self._config[CONF_SSL]:
self._ssl_context = ssl.create_default_context()
if not self._config[CONF_VERIFY_SSL]:
self._ssl_context.check_hostname = False
self._ssl_context.verify_mode = ssl.CERT_NONE
self._state: str | None = None
self.update()
@property
def name(self) -> str:
"""Return the name of this sensor."""
return self._config[CONF_NAME]
@property
def state(self) -> StateType:
"""Return the state of the device."""
@ -120,64 +39,3 @@ class TcpSensor(SensorEntity):
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of this entity."""
return self._config[CONF_UNIT_OF_MEASUREMENT]
def update(self) -> None:
"""Get the latest value for this sensor."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(self._config[CONF_TIMEOUT])
try:
sock.connect((self._config[CONF_HOST], self._config[CONF_PORT]))
except OSError as err:
_LOGGER.error(
"Unable to connect to %s on port %s: %s",
self._config[CONF_HOST],
self._config[CONF_PORT],
err,
)
return
if self._ssl_context is not None:
sock = self._ssl_context.wrap_socket(
sock, server_hostname=self._config[CONF_HOST]
)
try:
sock.send(self._config[CONF_PAYLOAD].encode())
except OSError as err:
_LOGGER.error(
"Unable to send payload %r to %s on port %s: %s",
self._config[CONF_PAYLOAD],
self._config[CONF_HOST],
self._config[CONF_PORT],
err,
)
return
readable, _, _ = select.select([sock], [], [], self._config[CONF_TIMEOUT])
if not readable:
_LOGGER.warning(
"Timeout (%s second(s)) waiting for a response after "
"sending %r to %s on port %s",
self._config[CONF_TIMEOUT],
self._config[CONF_PAYLOAD],
self._config[CONF_HOST],
self._config[CONF_PORT],
)
return
value = sock.recv(self._config[CONF_BUFFER_SIZE]).decode()
value_template = self._config[CONF_VALUE_TEMPLATE]
if value_template is not None:
try:
self._state = value_template.render(parse_result=False, value=value)
return
except TemplateError:
_LOGGER.error(
"Unable to render template of %r with value: %r",
self._config[CONF_VALUE_TEMPLATE],
value,
)
return
self._state = value

View file

@ -20,9 +20,9 @@ TEST_ENTITY = "binary_sensor.test_name"
def mock_socket_fixture():
"""Mock the socket."""
with patch(
"homeassistant.components.tcp.sensor.socket.socket"
"homeassistant.components.tcp.common.socket.socket"
) as mock_socket, patch(
"homeassistant.components.tcp.sensor.select.select",
"homeassistant.components.tcp.common.select.select",
return_value=(True, False, False),
):
# yield the return value of the socket context manager

View file

@ -4,7 +4,7 @@ from unittest.mock import call, patch
import pytest
import homeassistant.components.tcp.sensor as tcp
import homeassistant.components.tcp.common as tcp
from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component
@ -41,7 +41,7 @@ socket_test_value = "value"
@pytest.fixture(name="mock_socket")
def mock_socket_fixture(mock_select):
"""Mock socket."""
with patch("homeassistant.components.tcp.sensor.socket.socket") as mock_socket:
with patch("homeassistant.components.tcp.common.socket.socket") as mock_socket:
socket_instance = mock_socket.return_value.__enter__.return_value
socket_instance.recv.return_value = socket_test_value.encode()
yield socket_instance
@ -51,7 +51,7 @@ def mock_socket_fixture(mock_select):
def mock_select_fixture():
"""Mock select."""
with patch(
"homeassistant.components.tcp.sensor.select.select",
"homeassistant.components.tcp.common.select.select",
return_value=(True, False, False),
) as mock_select:
yield mock_select
@ -61,7 +61,7 @@ def mock_select_fixture():
def mock_ssl_context_fixture():
"""Mock select."""
with patch(
"homeassistant.components.tcp.sensor.ssl.create_default_context",
"homeassistant.components.tcp.common.ssl.create_default_context",
) as mock_ssl_context:
mock_ssl_context.return_value.wrap_socket.return_value.recv.return_value = (
socket_test_value + "_ssl"