From df3f7687d4157bad18a87b0442d6590152eebb48 Mon Sep 17 00:00:00 2001 From: jjlawren Date: Mon, 2 Mar 2020 07:44:24 -0600 Subject: [PATCH] Refactor Certificate Expiry Sensor (#32066) * Cert Expiry refactor * Unused parameter * Reduce delay * Deprecate 'name' config * Use config entry unique_id * Fix logic bugs found with tests * Rewrite tests to use config flow core interfaces, validate created sensors * Update strings * Minor consistency fix * Review fixes, complete test coverage * Move error handling to helper * Subclass exceptions * Better tests * Use first object reference * Fix docstring --- .coveragerc | 1 - .../components/cert_expiry/config_flow.py | 95 +++--- homeassistant/components/cert_expiry/const.py | 1 - .../components/cert_expiry/errors.py | 26 ++ .../components/cert_expiry/helper.py | 30 +- .../components/cert_expiry/sensor.py | 120 +++---- .../components/cert_expiry/strings.json | 7 +- tests/components/cert_expiry/const.py | 3 + .../cert_expiry/test_config_flow.py | 298 +++++++++++------- tests/components/cert_expiry/test_init.py | 96 ++++++ tests/components/cert_expiry/test_sensors.py | 211 +++++++++++++ 11 files changed, 652 insertions(+), 236 deletions(-) create mode 100644 homeassistant/components/cert_expiry/errors.py create mode 100644 tests/components/cert_expiry/const.py create mode 100644 tests/components/cert_expiry/test_init.py create mode 100644 tests/components/cert_expiry/test_sensors.py diff --git a/.coveragerc b/.coveragerc index 56084a049a0..9bffb4350f9 100644 --- a/.coveragerc +++ b/.coveragerc @@ -109,7 +109,6 @@ omit = homeassistant/components/canary/alarm_control_panel.py homeassistant/components/canary/camera.py homeassistant/components/cast/* - homeassistant/components/cert_expiry/sensor.py homeassistant/components/cert_expiry/helper.py homeassistant/components/channels/* homeassistant/components/cisco_ios/device_tracker.py diff --git a/homeassistant/components/cert_expiry/config_flow.py b/homeassistant/components/cert_expiry/config_flow.py index f3bd2f07d63..3f77701906f 100644 --- a/homeassistant/components/cert_expiry/config_flow.py +++ b/homeassistant/components/cert_expiry/config_flow.py @@ -1,29 +1,23 @@ """Config flow for the Cert Expiry platform.""" import logging -import socket -import ssl import voluptuous as vol from homeassistant import config_entries -from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT -from homeassistant.core import HomeAssistant, callback +from homeassistant.const import CONF_HOST, CONF_PORT -from .const import DEFAULT_NAME, DEFAULT_PORT, DOMAIN -from .helper import get_cert +from .const import DEFAULT_PORT, DOMAIN # pylint: disable=unused-import +from .errors import ( + ConnectionRefused, + ConnectionTimeout, + ResolveFailed, + ValidationFailure, +) +from .helper import get_cert_time_to_expiry _LOGGER = logging.getLogger(__name__) -@callback -def certexpiry_entries(hass: HomeAssistant): - """Return the host,port tuples for the domain.""" - return set( - (entry.data[CONF_HOST], entry.data[CONF_PORT]) - for entry in hass.config_entries.async_entries(DOMAIN) - ) - - class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow.""" @@ -34,59 +28,47 @@ class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Initialize the config flow.""" self._errors = {} - def _prt_in_configuration_exists(self, user_input) -> bool: - """Return True if host, port combination exists in configuration.""" - host = user_input[CONF_HOST] - port = user_input.get(CONF_PORT, DEFAULT_PORT) - if (host, port) in certexpiry_entries(self.hass): - return True - return False - async def _test_connection(self, user_input=None): - """Test connection to the server and try to get the certtificate.""" - host = user_input[CONF_HOST] + """Test connection to the server and try to get the certificate.""" try: - await self.hass.async_add_executor_job( - get_cert, host, user_input.get(CONF_PORT, DEFAULT_PORT) + await get_cert_time_to_expiry( + self.hass, + user_input[CONF_HOST], + user_input.get(CONF_PORT, DEFAULT_PORT), ) return True - except socket.gaierror: - _LOGGER.error("Host cannot be resolved: %s", host) + except ResolveFailed: self._errors[CONF_HOST] = "resolve_failed" - except socket.timeout: - _LOGGER.error("Timed out connecting to %s", host) + except ConnectionTimeout: self._errors[CONF_HOST] = "connection_timeout" - except ssl.CertificateError as err: - if "doesn't match" in err.args[0]: - _LOGGER.error("Certificate does not match host: %s", host) - self._errors[CONF_HOST] = "wrong_host" - else: - _LOGGER.error("Certificate could not be validated: %s", host) - self._errors[CONF_HOST] = "certificate_error" - except ssl.SSLError: - _LOGGER.error("Certificate could not be validated: %s", host) - self._errors[CONF_HOST] = "certificate_error" + except ConnectionRefused: + self._errors[CONF_HOST] = "connection_refused" + except ValidationFailure: + return True return False async def async_step_user(self, user_input=None): """Step when user initializes a integration.""" self._errors = {} if user_input is not None: - # set some defaults in case we need to return to the form - if self._prt_in_configuration_exists(user_input): - self._errors[CONF_HOST] = "host_port_exists" - else: - if await self._test_connection(user_input): - return self.async_create_entry( - title=user_input.get(CONF_NAME, DEFAULT_NAME), - data={ - CONF_HOST: user_input[CONF_HOST], - CONF_PORT: user_input.get(CONF_PORT, DEFAULT_PORT), - }, - ) + host = user_input[CONF_HOST] + port = user_input.get(CONF_PORT, DEFAULT_PORT) + await self.async_set_unique_id(f"{host}:{port}") + self._abort_if_unique_id_configured() + + if await self._test_connection(user_input): + title_port = f":{port}" if port != DEFAULT_PORT else "" + title = f"{host}{title_port}" + return self.async_create_entry( + title=title, data={CONF_HOST: host, CONF_PORT: port}, + ) + if ( # pylint: disable=no-member + self.context["source"] == config_entries.SOURCE_IMPORT + ): + _LOGGER.error("Config import failed for %s", user_input[CONF_HOST]) + return self.async_abort(reason="import_failed") else: user_input = {} - user_input[CONF_NAME] = DEFAULT_NAME user_input[CONF_HOST] = "" user_input[CONF_PORT] = DEFAULT_PORT @@ -94,9 +76,6 @@ class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): step_id="user", data_schema=vol.Schema( { - vol.Required( - CONF_NAME, default=user_input.get(CONF_NAME, DEFAULT_NAME) - ): str, vol.Required(CONF_HOST, default=user_input[CONF_HOST]): str, vol.Required( CONF_PORT, default=user_input.get(CONF_PORT, DEFAULT_PORT) @@ -111,6 +90,4 @@ class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): Only host was required in the yaml file all other fields are optional """ - if self._prt_in_configuration_exists(user_input): - return self.async_abort(reason="host_port_exists") return await self.async_step_user(user_input) diff --git a/homeassistant/components/cert_expiry/const.py b/homeassistant/components/cert_expiry/const.py index 4129781f2a0..00d5ac9e923 100644 --- a/homeassistant/components/cert_expiry/const.py +++ b/homeassistant/components/cert_expiry/const.py @@ -1,6 +1,5 @@ """Const for Cert Expiry.""" DOMAIN = "cert_expiry" -DEFAULT_NAME = "SSL Certificate Expiry" DEFAULT_PORT = 443 TIMEOUT = 10.0 diff --git a/homeassistant/components/cert_expiry/errors.py b/homeassistant/components/cert_expiry/errors.py new file mode 100644 index 00000000000..a3b73c84f2a --- /dev/null +++ b/homeassistant/components/cert_expiry/errors.py @@ -0,0 +1,26 @@ +"""Errors for the cert_expiry integration.""" +from homeassistant.exceptions import HomeAssistantError + + +class CertExpiryException(HomeAssistantError): + """Base class for cert_expiry exceptions.""" + + +class TemporaryFailure(CertExpiryException): + """Temporary failure has occurred.""" + + +class ValidationFailure(CertExpiryException): + """Certificate validation failure has occurred.""" + + +class ResolveFailed(TemporaryFailure): + """Name resolution failed.""" + + +class ConnectionTimeout(TemporaryFailure): + """Network connection timed out.""" + + +class ConnectionRefused(TemporaryFailure): + """Network connection refused.""" diff --git a/homeassistant/components/cert_expiry/helper.py b/homeassistant/components/cert_expiry/helper.py index cd49588ec89..bb9f2762f3a 100644 --- a/homeassistant/components/cert_expiry/helper.py +++ b/homeassistant/components/cert_expiry/helper.py @@ -1,12 +1,19 @@ """Helper functions for the Cert Expiry platform.""" +from datetime import datetime import socket import ssl from .const import TIMEOUT +from .errors import ( + ConnectionRefused, + ConnectionTimeout, + ResolveFailed, + ValidationFailure, +) def get_cert(host, port): - """Get the ssl certificate for the host and port combination.""" + """Get the certificate for the host and port combination.""" ctx = ssl.create_default_context() address = (host, port) with socket.create_connection(address, timeout=TIMEOUT) as sock: @@ -14,3 +21,24 @@ def get_cert(host, port): # pylint disable: https://github.com/PyCQA/pylint/issues/3166 cert = ssock.getpeercert() # pylint: disable=no-member return cert + + +async def get_cert_time_to_expiry(hass, hostname, port): + """Return the certificate's time to expiry in days.""" + try: + cert = await hass.async_add_executor_job(get_cert, hostname, port) + except socket.gaierror: + raise ResolveFailed(f"Cannot resolve hostname: {hostname}") + except socket.timeout: + raise ConnectionTimeout(f"Connection timeout with server: {hostname}:{port}") + except ConnectionRefusedError: + raise ConnectionRefused(f"Connection refused by server: {hostname}:{port}") + except ssl.CertificateError as err: + raise ValidationFailure(err.verify_message) + except ssl.SSLError as err: + raise ValidationFailure(err.args[0]) + + ts_seconds = ssl.cert_time_to_seconds(cert["notAfter"]) + timestamp = datetime.fromtimestamp(ts_seconds) + expiry = timestamp - datetime.today() + return expiry.days diff --git a/homeassistant/components/cert_expiry/sensor.py b/homeassistant/components/cert_expiry/sensor.py index b4437ca5834..39ec2c35ac7 100644 --- a/homeassistant/components/cert_expiry/sensor.py +++ b/homeassistant/components/cert_expiry/sensor.py @@ -1,8 +1,6 @@ """Counter for the days until an HTTPS (TLS) certificate will expire.""" -from datetime import datetime, timedelta +from datetime import timedelta import logging -import socket -import ssl import voluptuous as vol @@ -16,47 +14,71 @@ from homeassistant.const import ( TIME_DAYS, ) from homeassistant.core import callback +from homeassistant.exceptions import PlatformNotReady import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity +from homeassistant.helpers.event import async_call_later -from .const import DEFAULT_NAME, DEFAULT_PORT, DOMAIN -from .helper import get_cert +from .const import DEFAULT_PORT, DOMAIN +from .errors import TemporaryFailure, ValidationFailure +from .helper import get_cert_time_to_expiry _LOGGER = logging.getLogger(__name__) SCAN_INTERVAL = timedelta(hours=12) -PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_HOST): cv.string, - vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, - } +PLATFORM_SCHEMA = vol.All( + cv.deprecated(CONF_NAME, invalidation_version="0.109"), + PLATFORM_SCHEMA.extend( + { + vol.Required(CONF_HOST): cv.string, + vol.Optional(CONF_NAME): cv.string, + vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, + } + ), ) async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): """Set up certificate expiry sensor.""" + @callback + def schedule_import(_): + """Schedule delayed import after HA is fully started.""" + async_call_later(hass, 10, do_import) + @callback def do_import(_): - """Process YAML import after HA is fully started.""" + """Process YAML import.""" hass.async_create_task( hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=dict(config) ) ) - # Delay to avoid validation during setup in case we're checking our own cert. - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, do_import) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, schedule_import) async def async_setup_entry(hass, entry, async_add_entities): """Add cert-expiry entry.""" + days = 0 + error = None + hostname = entry.data[CONF_HOST] + port = entry.data[CONF_PORT] + + if entry.unique_id is None: + hass.config_entries.async_update_entry(entry, unique_id=f"{hostname}:{port}") + + try: + days = await get_cert_time_to_expiry(hass, hostname, port) + except TemporaryFailure as err: + _LOGGER.error(err) + raise PlatformNotReady + except ValidationFailure as err: + error = err + async_add_entities( - [SSLCertificate(entry.title, entry.data[CONF_HOST], entry.data[CONF_PORT])], - False, - # Don't update in case we're checking our own cert. + [SSLCertificate(hostname, port, days, error)], False, ) return True @@ -64,14 +86,18 @@ async def async_setup_entry(hass, entry, async_add_entities): class SSLCertificate(Entity): """Implementation of the certificate expiry sensor.""" - def __init__(self, sensor_name, server_name, server_port): + def __init__(self, server_name, server_port, days, error): """Initialize the sensor.""" self.server_name = server_name self.server_port = server_port - self._name = sensor_name - self._state = None - self._available = False + display_port = f":{server_port}" if server_port != DEFAULT_PORT else "" + self._name = f"Cert Expiry ({self.server_name}{display_port})" + self._available = True + self._error = error + self._state = days self._valid = False + if error is None: + self._valid = True @property def name(self): @@ -103,50 +129,38 @@ class SSLCertificate(Entity): """Return the availability of the sensor.""" return self._available - async def async_added_to_hass(self): - """Once the entity is added we should update to get the initial data loaded.""" - - @callback - def do_update(_): - """Run the update method when the start event was fired.""" - self.async_schedule_update_ha_state(True) - - if self.hass.is_running: - self.async_schedule_update_ha_state(True) - else: - # Delay until HA is fully started in case we're checking our own cert. - self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, do_update) - - def update(self): + async def async_update(self): """Fetch the certificate information.""" try: - cert = get_cert(self.server_name, self.server_port) - except socket.gaierror: - _LOGGER.error("Cannot resolve hostname: %s", self.server_name) + days_to_expiry = await get_cert_time_to_expiry( + self.hass, self.server_name, self.server_port + ) + except TemporaryFailure as err: + _LOGGER.error(err.args[0]) self._available = False - self._valid = False return - except socket.timeout: - _LOGGER.error("Connection timeout with server: %s", self.server_name) - self._available = False - self._valid = False - return - except (ssl.CertificateError, ssl.SSLError): + except ValidationFailure as err: + _LOGGER.error( + "Certificate validation error: %s [%s]", self.server_name, err + ) self._available = True + self._error = err self._state = 0 self._valid = False return + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Unknown error checking %s:%s", self.server_name, self.server_port + ) + self._available = False + return - ts_seconds = ssl.cert_time_to_seconds(cert["notAfter"]) - timestamp = datetime.fromtimestamp(ts_seconds) - expiry = timestamp - datetime.today() self._available = True - self._state = expiry.days + self._error = None + self._state = days_to_expiry self._valid = True @property def device_state_attributes(self): """Return additional sensor state attributes.""" - attr = {"is_valid": self._valid} - - return attr + return {"is_valid": self._valid, "error": str(self._error)} diff --git a/homeassistant/components/cert_expiry/strings.json b/homeassistant/components/cert_expiry/strings.json index e5e670d214f..4d4982a19af 100644 --- a/homeassistant/components/cert_expiry/strings.json +++ b/homeassistant/components/cert_expiry/strings.json @@ -12,14 +12,13 @@ } }, "error": { - "host_port_exists": "This host and port combination is already configured", "resolve_failed": "This host can not be resolved", "connection_timeout": "Timeout when connecting to this host", - "certificate_error": "Certificate could not be validated", - "wrong_host": "Certificate does not match hostname" + "connection_refused": "Connection refused when connecting to host" }, "abort": { - "host_port_exists": "This host and port combination is already configured" + "already_configured": "This host and port combination is already configured", + "import_failed": "Import from config failed" } } } diff --git a/tests/components/cert_expiry/const.py b/tests/components/cert_expiry/const.py new file mode 100644 index 00000000000..9ddbeca61c3 --- /dev/null +++ b/tests/components/cert_expiry/const.py @@ -0,0 +1,3 @@ +"""Constants for cert_expiry tests.""" +PORT = 443 +HOST = "example.com" diff --git a/tests/components/cert_expiry/test_config_flow.py b/tests/components/cert_expiry/test_config_flow.py index 71005672fdb..1b2cc175dcb 100644 --- a/tests/components/cert_expiry/test_config_flow.py +++ b/tests/components/cert_expiry/test_config_flow.py @@ -1,154 +1,218 @@ """Tests for the Cert Expiry config flow.""" import socket import ssl -from unittest.mock import patch -import pytest +from asynctest import patch from homeassistant import data_entry_flow -from homeassistant.components.cert_expiry import config_flow -from homeassistant.components.cert_expiry.const import DEFAULT_NAME, DEFAULT_PORT +from homeassistant.components.cert_expiry.const import DEFAULT_PORT, DOMAIN from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT -from tests.common import MockConfigEntry, mock_coro +from .const import HOST, PORT -NAME = "Cert Expiry test 1 2 3" -PORT = 443 -HOST = "example.com" +from tests.common import MockConfigEntry -@pytest.fixture(name="test_connect") -def mock_controller(): - """Mock a successful _prt_in_configuration_exists.""" - with patch( - "homeassistant.components.cert_expiry.config_flow.CertexpiryConfigFlow._test_connection", - side_effect=lambda *_: mock_coro(True), - ): - yield - - -def init_config_flow(hass): - """Init a configuration flow.""" - flow = config_flow.CertexpiryConfigFlow() - flow.hass = hass - return flow - - -async def test_user(hass, test_connect): +async def test_user(hass): """Test user config.""" - flow = init_config_flow(hass) - - result = await flow.async_step_user() + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "user"} + ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" - # tets with all provided - result = await flow.async_step_user( - {CONF_NAME: NAME, CONF_HOST: HOST, CONF_PORT: PORT} - ) + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry" + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: HOST, CONF_PORT: PORT} + ) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == NAME + assert result["title"] == HOST assert result["data"][CONF_HOST] == HOST assert result["data"][CONF_PORT] == PORT + assert result["result"].unique_id == f"{HOST}:{PORT}" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() -async def test_import(hass, test_connect): - """Test import step.""" - flow = init_config_flow(hass) - - # import with only host - result = await flow.async_step_import({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == DEFAULT_NAME - assert result["data"][CONF_HOST] == HOST - assert result["data"][CONF_PORT] == DEFAULT_PORT - - # import with host and name - result = await flow.async_step_import({CONF_HOST: HOST, CONF_NAME: NAME}) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == NAME - assert result["data"][CONF_HOST] == HOST - assert result["data"][CONF_PORT] == DEFAULT_PORT - - # improt with host and port - result = await flow.async_step_import({CONF_HOST: HOST, CONF_PORT: PORT}) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == DEFAULT_NAME - assert result["data"][CONF_HOST] == HOST - assert result["data"][CONF_PORT] == PORT - - # import with all - result = await flow.async_step_import( - {CONF_HOST: HOST, CONF_PORT: PORT, CONF_NAME: NAME} - ) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == NAME - assert result["data"][CONF_HOST] == HOST - assert result["data"][CONF_PORT] == PORT - - -async def test_abort_if_already_setup(hass, test_connect): - """Test we abort if the cert is already setup.""" - flow = init_config_flow(hass) - MockConfigEntry( - domain="cert_expiry", - data={CONF_PORT: DEFAULT_PORT, CONF_NAME: NAME, CONF_HOST: HOST}, - ).add_to_hass(hass) - - # Should fail, same HOST and PORT (default) - result = await flow.async_step_import( - {CONF_HOST: HOST, CONF_NAME: NAME, CONF_PORT: DEFAULT_PORT} - ) - assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT - assert result["reason"] == "host_port_exists" - - # Should be the same HOST and PORT (default) - result = await flow.async_step_user( - {CONF_HOST: HOST, CONF_NAME: NAME, CONF_PORT: DEFAULT_PORT} +async def test_user_with_bad_cert(hass): + """Test user config with bad certificate.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "user"} ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "host_port_exists"} + assert result["step_id"] == "user" + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=ssl.SSLError("some error"), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: HOST, CONF_PORT: PORT} + ) - # SHOULD pass, same Host diff PORT - result = await flow.async_step_import( - {CONF_HOST: HOST, CONF_NAME: NAME, CONF_PORT: 888} - ) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == NAME + assert result["title"] == HOST + assert result["data"][CONF_HOST] == HOST + assert result["data"][CONF_PORT] == PORT + assert result["result"].unique_id == f"{HOST}:{PORT}" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() + + +async def test_import_host_only(hass): + """Test import with host only.""" + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry", + return_value=1, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "import"}, data={CONF_HOST: HOST} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"][CONF_HOST] == HOST + assert result["data"][CONF_PORT] == DEFAULT_PORT + assert result["result"].unique_id == f"{HOST}:{DEFAULT_PORT}" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() + + +async def test_import_host_and_port(hass): + """Test import with host and port.""" + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry", + return_value=1, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": "import"}, + data={CONF_HOST: HOST, CONF_PORT: PORT}, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"][CONF_HOST] == HOST + assert result["data"][CONF_PORT] == PORT + assert result["result"].unique_id == f"{HOST}:{PORT}" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() + + +async def test_import_non_default_port(hass): + """Test import with host and non-default port.""" + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry" + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "import"}, data={CONF_HOST: HOST, CONF_PORT: 888} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == f"{HOST}:888" assert result["data"][CONF_HOST] == HOST assert result["data"][CONF_PORT] == 888 + assert result["result"].unique_id == f"{HOST}:888" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() + + +async def test_import_with_name(hass): + """Test import with name (deprecated).""" + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry", + return_value=1, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": "import"}, + data={CONF_NAME: "legacy", CONF_HOST: HOST, CONF_PORT: PORT}, + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == HOST + assert result["data"][CONF_HOST] == HOST + assert result["data"][CONF_PORT] == PORT + assert result["result"].unique_id == f"{HOST}:{PORT}" + + with patch("homeassistant.components.cert_expiry.sensor.async_setup_entry"): + await hass.async_block_till_done() + + +async def test_bad_import(hass): + """Test import step.""" + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=ConnectionRefusedError(), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "import"}, data={CONF_HOST: HOST} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "import_failed" + + +async def test_abort_if_already_setup(hass): + """Test we abort if the cert is already setup.""" + MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "import"}, data={CONF_HOST: HOST, CONF_PORT: PORT} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "user"}, data={CONF_HOST: HOST, CONF_PORT: PORT} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" async def test_abort_on_socket_failed(hass): """Test we abort of we have errors during socket creation.""" - flow = init_config_flow(hass) - - with patch("socket.create_connection", side_effect=socket.gaierror()): - result = await flow.async_step_user({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "resolve_failed"} - - with patch("socket.create_connection", side_effect=socket.timeout()): - result = await flow.async_step_user({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "connection_timeout"} + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "user"} + ) with patch( - "socket.create_connection", - side_effect=ssl.CertificateError(f"{HOST} doesn't match somethingelse.com"), + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=socket.gaierror(), ): - result = await flow.async_step_user({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "wrong_host"} + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: HOST} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {CONF_HOST: "resolve_failed"} with patch( - "socket.create_connection", side_effect=ssl.CertificateError("different error") + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=socket.timeout(), ): - result = await flow.async_step_user({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "certificate_error"} + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: HOST} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {CONF_HOST: "connection_timeout"} - with patch("socket.create_connection", side_effect=ssl.SSLError()): - result = await flow.async_step_user({CONF_HOST: HOST}) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_HOST: "certificate_error"} + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=ConnectionRefusedError, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: HOST} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {CONF_HOST: "connection_refused"} diff --git a/tests/components/cert_expiry/test_init.py b/tests/components/cert_expiry/test_init.py new file mode 100644 index 00000000000..d4419b48370 --- /dev/null +++ b/tests/components/cert_expiry/test_init.py @@ -0,0 +1,96 @@ +"""Tests for Cert Expiry setup.""" +from datetime import timedelta + +from asynctest import patch + +from homeassistant.components.cert_expiry.const import DOMAIN +from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN +from homeassistant.config_entries import ENTRY_STATE_LOADED, ENTRY_STATE_NOT_LOADED +from homeassistant.const import CONF_HOST, CONF_PORT, EVENT_HOMEASSISTANT_START +from homeassistant.setup import async_setup_component +import homeassistant.util.dt as dt_util + +from .const import HOST, PORT + +from tests.common import MockConfigEntry, async_fire_time_changed + + +async def test_setup_with_config(hass): + """Test setup component with config.""" + config = { + SENSOR_DOMAIN: [ + {"platform": DOMAIN, CONF_HOST: HOST, CONF_PORT: PORT}, + {"platform": DOMAIN, CONF_HOST: HOST, CONF_PORT: 888}, + ], + } + assert await async_setup_component(hass, SENSOR_DOMAIN, config) is True + await hass.async_block_till_done() + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + next_update = dt_util.utcnow() + timedelta(seconds=20) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.config_flow.get_cert_time_to_expiry", + return_value=100, + ), patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + await hass.async_block_till_done() + + assert len(hass.config_entries.async_entries(DOMAIN)) == 2 + + +async def test_update_unique_id(hass): + """Test updating a config entry without a unique_id.""" + entry = MockConfigEntry(domain=DOMAIN, data={CONF_HOST: HOST, CONF_PORT: PORT}) + entry.add_to_hass(hass) + + config_entries = hass.config_entries.async_entries(DOMAIN) + assert len(config_entries) == 1 + assert entry is config_entries[0] + assert not entry.unique_id + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + assert await async_setup_component(hass, DOMAIN, {}) is True + await hass.async_block_till_done() + + assert entry.state == ENTRY_STATE_LOADED + assert entry.unique_id == f"{HOST}:{PORT}" + + +async def test_unload_config_entry(hass): + """Test unloading a config entry.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + entry.add_to_hass(hass) + + config_entries = hass.config_entries.async_entries(DOMAIN) + assert len(config_entries) == 1 + assert entry is config_entries[0] + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + assert await async_setup_component(hass, DOMAIN, {}) is True + await hass.async_block_till_done() + + assert entry.state == ENTRY_STATE_LOADED + state = hass.states.get("sensor.cert_expiry_example_com") + assert state.state == "100" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + await hass.config_entries.async_unload(entry.entry_id) + + assert entry.state == ENTRY_STATE_NOT_LOADED + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is None diff --git a/tests/components/cert_expiry/test_sensors.py b/tests/components/cert_expiry/test_sensors.py new file mode 100644 index 00000000000..6594b0988e7 --- /dev/null +++ b/tests/components/cert_expiry/test_sensors.py @@ -0,0 +1,211 @@ +"""Tests for the Cert Expiry sensors.""" +from datetime import timedelta +import socket +import ssl + +from asynctest import patch + +from homeassistant.const import CONF_HOST, CONF_PORT, STATE_UNAVAILABLE +import homeassistant.util.dt as dt_util + +from .const import HOST, PORT + +from tests.common import MockConfigEntry, async_fire_time_changed + + +async def test_async_setup_entry(hass): + """Test async_setup_entry.""" + entry = MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "100" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + +async def test_async_setup_entry_bad_cert(hass): + """Test async_setup_entry with a bad/expired cert.""" + entry = MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=ssl.SSLError("some error"), + ): + entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "0" + assert state.attributes.get("error") == "some error" + assert not state.attributes.get("is_valid") + + +async def test_async_setup_entry_host_unavailable(hass): + """Test async_setup_entry when host is unavailable.""" + entry = MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=socket.gaierror, + ): + entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is None + + next_update = dt_util.utcnow() + timedelta(seconds=45) + async_fire_time_changed(hass, next_update) + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=socket.gaierror, + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is None + + +async def test_update_sensor(hass): + """Test async_update for sensor.""" + entry = MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "100" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + next_update = dt_util.utcnow() + timedelta(hours=12) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=99, + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "99" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + +async def test_update_sensor_network_errors(hass): + """Test async_update for sensor.""" + entry = MockConfigEntry( + domain="cert_expiry", + data={CONF_HOST: HOST, CONF_PORT: PORT}, + unique_id=f"{HOST}:{PORT}", + ) + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=100, + ): + entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "100" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + next_update = dt_util.utcnow() + timedelta(hours=12) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=socket.gaierror, + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state.state == STATE_UNAVAILABLE + + next_update = dt_util.utcnow() + timedelta(hours=12) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.sensor.get_cert_time_to_expiry", + return_value=99, + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "99" + assert state.attributes.get("error") == "None" + assert state.attributes.get("is_valid") + + next_update = dt_util.utcnow() + timedelta(hours=12) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", + side_effect=ssl.SSLError("something bad"), + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state is not None + assert state.state != STATE_UNAVAILABLE + assert state.state == "0" + assert state.attributes.get("error") == "something bad" + assert not state.attributes.get("is_valid") + + next_update = dt_util.utcnow() + timedelta(hours=12) + async_fire_time_changed(hass, next_update) + + with patch( + "homeassistant.components.cert_expiry.helper.get_cert", side_effect=Exception() + ): + await hass.async_block_till_done() + + state = hass.states.get("sensor.cert_expiry_example_com") + assert state.state == STATE_UNAVAILABLE