From 8567fe94e1838f4a5c04f564e296af8f4b773eac Mon Sep 17 00:00:00 2001 From: Rob Bierbooms Date: Sat, 5 Sep 2020 12:05:46 +0200 Subject: [PATCH] Add connection validation on import for dsmr integration (#39664) --- homeassistant/components/dsmr/config_flow.py | 125 ++++++++++++++- homeassistant/components/dsmr/const.py | 3 + tests/components/dsmr/test_config_flow.py | 158 ++++++++++++++++++- tests/components/dsmr/test_sensor.py | 7 +- 4 files changed, 283 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/dsmr/config_flow.py b/homeassistant/components/dsmr/config_flow.py index d3aa770ff60..d0d0304a02a 100644 --- a/homeassistant/components/dsmr/config_flow.py +++ b/homeassistant/components/dsmr/config_flow.py @@ -1,15 +1,114 @@ """Config flow for DSMR integration.""" +import asyncio +from functools import partial import logging from typing import Any, Dict, Optional -from homeassistant import config_entries +from async_timeout import timeout +from dsmr_parser import obis_references as obis_ref +from dsmr_parser.clients.protocol import create_dsmr_reader, create_tcp_dsmr_reader +import serial + +from homeassistant import config_entries, core, exceptions from homeassistant.const import CONF_HOST, CONF_PORT -from .const import DOMAIN # pylint:disable=unused-import +from .const import ( # pylint:disable=unused-import + CONF_DSMR_VERSION, + CONF_SERIAL_ID, + CONF_SERIAL_ID_GAS, + DOMAIN, +) _LOGGER = logging.getLogger(__name__) +class DSMRConnection: + """Test the connection to DSMR and receive telegram to read serial ids.""" + + def __init__(self, host, port, dsmr_version): + """Initialize.""" + self._host = host + self._port = port + self._dsmr_version = dsmr_version + self._telegram = {} + + def equipment_identifier(self): + """Equipment identifier.""" + if obis_ref.EQUIPMENT_IDENTIFIER in self._telegram: + dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER] + return getattr(dsmr_object, "value", None) + + def equipment_identifier_gas(self): + """Equipment identifier gas.""" + if obis_ref.EQUIPMENT_IDENTIFIER_GAS in self._telegram: + dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER_GAS] + return getattr(dsmr_object, "value", None) + + async def validate_connect(self, hass: core.HomeAssistant) -> bool: + """Test if we can validate connection with the device.""" + + def update_telegram(telegram): + self._telegram = telegram + + transport.close() + + if self._host is None: + reader_factory = partial( + create_dsmr_reader, + self._port, + self._dsmr_version, + update_telegram, + loop=hass.loop, + ) + else: + reader_factory = partial( + create_tcp_dsmr_reader, + self._host, + self._port, + self._dsmr_version, + update_telegram, + loop=hass.loop, + ) + + try: + transport, protocol = await asyncio.create_task(reader_factory()) + except (serial.serialutil.SerialException, OSError): + _LOGGER.exception("Error connecting to DSMR") + return False + + if transport: + try: + async with timeout(30): + await protocol.wait_closed() + except asyncio.TimeoutError: + # Timeout (no data received), close transport and return True (if telegram is empty, will result in CannotCommunicate error) + transport.close() + await protocol.wait_closed() + return True + + +async def _validate_dsmr_connection(hass: core.HomeAssistant, data): + """Validate the user input allows us to connect.""" + conn = DSMRConnection(data.get(CONF_HOST), data[CONF_PORT], data[CONF_DSMR_VERSION]) + + if not await conn.validate_connect(hass): + raise CannotConnect + + equipment_identifier = conn.equipment_identifier() + equipment_identifier_gas = conn.equipment_identifier_gas() + + # Check only for equipment identifier in case no gas meter is connected + if equipment_identifier is None: + raise CannotCommunicate + + info = { + CONF_SERIAL_ID: equipment_identifier, + CONF_SERIAL_ID_GAS: equipment_identifier_gas, + } + + return info + + class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for DSMR.""" @@ -55,9 +154,29 @@ class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if status is not None: return status + try: + info = await _validate_dsmr_connection(self.hass, import_config) + except CannotConnect: + return self.async_abort(reason="cannot_connect") + except CannotCommunicate: + return self.async_abort(reason="cannot_communicate") + if host is not None: name = f"{host}:{port}" else: name = port - return self.async_create_entry(title=name, data=import_config) + data = {**import_config, **info} + + await self.async_set_unique_id(info[CONF_SERIAL_ID]) + self._abort_if_unique_id_configured(data) + + return self.async_create_entry(title=name, data=data) + + +class CannotConnect(exceptions.HomeAssistantError): + """Error to indicate we cannot connect.""" + + +class CannotCommunicate(exceptions.HomeAssistantError): + """Error to indicate we cannot connect.""" diff --git a/homeassistant/components/dsmr/const.py b/homeassistant/components/dsmr/const.py index 110e6b46a99..ed5f8bf0ed7 100644 --- a/homeassistant/components/dsmr/const.py +++ b/homeassistant/components/dsmr/const.py @@ -8,6 +8,9 @@ CONF_DSMR_VERSION = "dsmr_version" CONF_RECONNECT_INTERVAL = "reconnect_interval" CONF_PRECISION = "precision" +CONF_SERIAL_ID = "serial_id" +CONF_SERIAL_ID_GAS = "serial_id_gas" + DEFAULT_DSMR_VERSION = "2.2" DEFAULT_PORT = "/dev/ttyUSB0" DEFAULT_PRECISION = 3 diff --git a/tests/components/dsmr/test_config_flow.py b/tests/components/dsmr/test_config_flow.py index 1d25d2cd915..c35562b4024 100644 --- a/tests/components/dsmr/test_config_flow.py +++ b/tests/components/dsmr/test_config_flow.py @@ -1,12 +1,65 @@ """Test the DSMR config flow.""" +import asyncio +from itertools import chain, repeat + +from dsmr_parser.clients.protocol import DSMRProtocol +from dsmr_parser.obis_references import EQUIPMENT_IDENTIFIER, EQUIPMENT_IDENTIFIER_GAS +from dsmr_parser.objects import CosemObject +import pytest +import serial + from homeassistant import config_entries, setup from homeassistant.components.dsmr import DOMAIN -from tests.async_mock import patch +from tests.async_mock import DEFAULT, AsyncMock, Mock, patch from tests.common import MockConfigEntry +SERIAL_DATA = {"serial_id": "12345678", "serial_id_gas": "123456789"} -async def test_import_usb(hass): + +@pytest.fixture +def mock_connection_factory(monkeypatch): + """Mock the create functions for serial and TCP Asyncio connections.""" + transport = Mock(spec=asyncio.Transport) + protocol = Mock(spec=DSMRProtocol) + + async def connection_factory(*args, **kwargs): + """Return mocked out Asyncio classes.""" + return (transport, protocol) + + connection_factory = Mock(wraps=connection_factory) + + # apply the mock to both connection factories + monkeypatch.setattr( + "homeassistant.components.dsmr.config_flow.create_dsmr_reader", + connection_factory, + ) + monkeypatch.setattr( + "homeassistant.components.dsmr.config_flow.create_tcp_dsmr_reader", + connection_factory, + ) + + protocol.telegram = { + EQUIPMENT_IDENTIFIER: CosemObject([{"value": "12345678", "unit": ""}]), + EQUIPMENT_IDENTIFIER_GAS: CosemObject([{"value": "123456789", "unit": ""}]), + } + + async def wait_closed(): + if isinstance(connection_factory.call_args_list[0][0][2], str): + # TCP + telegram_callback = connection_factory.call_args_list[0][0][3] + else: + # Serial + telegram_callback = connection_factory.call_args_list[0][0][2] + + telegram_callback(protocol.telegram) + + protocol.wait_closed = wait_closed + + return connection_factory, transport, protocol + + +async def test_import_usb(hass, mock_connection_factory): """Test we can import.""" await setup.async_setup_component(hass, "persistent_notification", {}) @@ -26,10 +79,103 @@ async def test_import_usb(hass): assert result["type"] == "create_entry" assert result["title"] == "/dev/ttyUSB0" - assert result["data"] == entry_data + assert result["data"] == {**entry_data, **SERIAL_DATA} -async def test_import_network(hass): +async def test_import_usb_failed_connection(hass, monkeypatch, mock_connection_factory): + """Test we can import.""" + (connection_factory, transport, protocol) = mock_connection_factory + + await setup.async_setup_component(hass, "persistent_notification", {}) + + entry_data = { + "port": "/dev/ttyUSB0", + "dsmr_version": "2.2", + "precision": 4, + "reconnect_interval": 30, + } + + # override the mock to have it fail the first time and succeed after + first_fail_connection_factory = AsyncMock( + return_value=(transport, protocol), + side_effect=chain([serial.serialutil.SerialException], repeat(DEFAULT)), + ) + + monkeypatch.setattr( + "homeassistant.components.dsmr.config_flow.create_dsmr_reader", + first_fail_connection_factory, + ) + + with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_IMPORT}, + data=entry_data, + ) + + assert result["type"] == "abort" + assert result["reason"] == "cannot_connect" + + +async def test_import_usb_no_data(hass, monkeypatch, mock_connection_factory): + """Test we can import.""" + (connection_factory, transport, protocol) = mock_connection_factory + + await setup.async_setup_component(hass, "persistent_notification", {}) + + entry_data = { + "port": "/dev/ttyUSB0", + "dsmr_version": "2.2", + "precision": 4, + "reconnect_interval": 30, + } + + # override the mock to have it fail the first time and succeed after + wait_closed = AsyncMock( + return_value=(transport, protocol), + side_effect=chain([asyncio.TimeoutError], repeat(DEFAULT)), + ) + + protocol.wait_closed = wait_closed + + with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_IMPORT}, + data=entry_data, + ) + + assert result["type"] == "abort" + assert result["reason"] == "cannot_communicate" + + +async def test_import_usb_wrong_telegram(hass, mock_connection_factory): + """Test we can import.""" + (connection_factory, transport, protocol) = mock_connection_factory + + await setup.async_setup_component(hass, "persistent_notification", {}) + + entry_data = { + "port": "/dev/ttyUSB0", + "dsmr_version": "2.2", + "precision": 4, + "reconnect_interval": 30, + } + + protocol.telegram = {} + + with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_IMPORT}, + data=entry_data, + ) + + assert result["type"] == "abort" + assert result["reason"] == "cannot_communicate" + + +async def test_import_network(hass, mock_connection_factory): """Test we can import from network.""" await setup.async_setup_component(hass, "persistent_notification", {}) @@ -50,10 +196,10 @@ async def test_import_network(hass): assert result["type"] == "create_entry" assert result["title"] == "localhost:1234" - assert result["data"] == entry_data + assert result["data"] == {**entry_data, **SERIAL_DATA} -async def test_import_update(hass): +async def test_import_update(hass, mock_connection_factory): """Test we can import.""" await setup.async_setup_component(hass, "persistent_notification", {}) diff --git a/tests/components/dsmr/test_sensor.py b/tests/components/dsmr/test_sensor.py index 73c11579070..f0ff2f85c57 100644 --- a/tests/components/dsmr/test_sensor.py +++ b/tests/components/dsmr/test_sensor.py @@ -61,8 +61,13 @@ async def test_setup_platform(hass, mock_connection_factory): "reconnect_interval": 30, } + serial_data = {"serial_id": "1234", "serial_id_gas": "5678"} + with patch("homeassistant.components.dsmr.async_setup", return_value=True), patch( "homeassistant.components.dsmr.async_setup_entry", return_value=True + ), patch( + "homeassistant.components.dsmr.config_flow._validate_dsmr_connection", + return_value=serial_data, ): assert await async_setup_component( hass, SENSOR_DOMAIN, {SENSOR_DOMAIN: entry_data} @@ -79,7 +84,7 @@ async def test_setup_platform(hass, mock_connection_factory): entry = conf_entries[0] assert entry.state == "loaded" - assert entry.data == entry_data + assert entry.data == {**entry_data, **serial_data} async def test_default_setup(hass, mock_connection_factory):