From 7a01d33814a6a6d684c0b937657c8abeee66f6cf Mon Sep 17 00:00:00 2001 From: Xiaonan Shen Date: Wed, 20 Jan 2021 23:40:23 +0800 Subject: [PATCH] Add empty password support to pi-hole (#37958) --- homeassistant/components/pi_hole/__init__.py | 9 +- .../components/pi_hole/config_flow.py | 95 ++++++++++++------- homeassistant/components/pi_hole/const.py | 2 + homeassistant/components/pi_hole/strings.json | 6 ++ tests/components/pi_hole/__init__.py | 18 +++- tests/components/pi_hole/test_config_flow.py | 53 +++++++++-- tests/components/pi_hole/test_init.py | 36 +++++++ 7 files changed, 174 insertions(+), 45 deletions(-) diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index c9b7937da73..2d540d936e5 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -26,6 +26,7 @@ from homeassistant.helpers.update_coordinator import ( from .const import ( CONF_LOCATION, + CONF_STATISTICS_ONLY, DATA_KEY_API, DATA_KEY_COORDINATOR, DEFAULT_LOCATION, @@ -83,6 +84,12 @@ async def async_setup_entry(hass, entry): location = entry.data[CONF_LOCATION] api_key = entry.data.get(CONF_API_KEY) + # For backward compatibility + if CONF_STATISTICS_ONLY not in entry.data: + hass.config_entries.async_update_entry( + entry, data={**entry.data, CONF_STATISTICS_ONLY: not api_key} + ) + _LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) try: @@ -146,7 +153,7 @@ async def async_unload_entry(hass, entry): def _async_platforms(entry): """Return platforms to be loaded / unloaded.""" platforms = ["sensor"] - if entry.data.get(CONF_API_KEY): + if not entry.data[CONF_STATISTICS_ONLY]: platforms.append("switch") else: platforms.append("binary_sensor") diff --git a/homeassistant/components/pi_hole/config_flow.py b/homeassistant/components/pi_hole/config_flow.py index c7061b05caa..a7d4b387b1c 100644 --- a/homeassistant/components/pi_hole/config_flow.py +++ b/homeassistant/components/pi_hole/config_flow.py @@ -8,9 +8,11 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.components.pi_hole.const import ( # pylint: disable=unused-import CONF_LOCATION, + CONF_STATISTICS_ONLY, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, + DEFAULT_STATISTICS_ONLY, DEFAULT_VERIFY_SSL, DOMAIN, ) @@ -33,6 +35,10 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL + def __init__(self): + """Initialize the config flow.""" + self._config = None + async def async_step_user(self, user_input=None): """Handle a flow initiated by the user.""" return await self.async_step_init(user_input) @@ -55,67 +61,93 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): location = user_input[CONF_LOCATION] tls = user_input[CONF_SSL] verify_tls = user_input[CONF_VERIFY_SSL] - api_token = user_input.get(CONF_API_KEY) endpoint = f"{host}/{location}" if await self._async_endpoint_existed(endpoint): return self.async_abort(reason="already_configured") try: - await self._async_try_connect( - host, location, tls, verify_tls, api_token - ) - return self.async_create_entry( - title=name, - data={ - CONF_HOST: host, - CONF_NAME: name, - CONF_LOCATION: location, - CONF_SSL: tls, - CONF_VERIFY_SSL: verify_tls, - CONF_API_KEY: api_token, - }, - ) + await self._async_try_connect(host, location, tls, verify_tls) except HoleError as ex: _LOGGER.debug("Connection failed: %s", ex) if is_import: _LOGGER.error("Failed to import: %s", ex) return self.async_abort(reason="cannot_connect") errors["base"] = "cannot_connect" + else: + self._config = { + CONF_HOST: host, + CONF_NAME: name, + CONF_LOCATION: location, + CONF_SSL: tls, + CONF_VERIFY_SSL: verify_tls, + } + if is_import: + api_key = user_input.get(CONF_API_KEY) + return self.async_create_entry( + title=name, + data={ + **self._config, + CONF_STATISTICS_ONLY: api_key is None, + CONF_API_KEY: api_key, + }, + ) + self._config[CONF_STATISTICS_ONLY] = user_input[CONF_STATISTICS_ONLY] + if self._config[CONF_STATISTICS_ONLY]: + return self.async_create_entry(title=name, data=self._config) + return await self.async_step_api_key() user_input = user_input or {} return self.async_show_form( step_id="user", data_schema=vol.Schema( { + vol.Required(CONF_HOST, default=user_input.get(CONF_HOST, "")): str, vol.Required( - CONF_HOST, default=user_input.get(CONF_HOST) or "" - ): str, - vol.Required( - CONF_PORT, default=user_input.get(CONF_PORT) or 80 + CONF_PORT, default=user_input.get(CONF_PORT, 80) ): vol.Coerce(int), vol.Required( - CONF_NAME, default=user_input.get(CONF_NAME) or DEFAULT_NAME + CONF_NAME, default=user_input.get(CONF_NAME, DEFAULT_NAME) ): str, vol.Required( CONF_LOCATION, - default=user_input.get(CONF_LOCATION) or DEFAULT_LOCATION, - ): str, - vol.Optional( - CONF_API_KEY, default=user_input.get(CONF_API_KEY) or "" + default=user_input.get(CONF_LOCATION, DEFAULT_LOCATION), ): str, vol.Required( - CONF_SSL, default=user_input.get(CONF_SSL) or DEFAULT_SSL + CONF_STATISTICS_ONLY, + default=user_input.get( + CONF_STATISTICS_ONLY, DEFAULT_STATISTICS_ONLY + ), + ): bool, + vol.Required( + CONF_SSL, + default=user_input.get(CONF_SSL, DEFAULT_SSL), ): bool, vol.Required( CONF_VERIFY_SSL, - default=user_input.get(CONF_VERIFY_SSL) or DEFAULT_VERIFY_SSL, + default=user_input.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL), ): bool, } ), errors=errors, ) + async def async_step_api_key(self, user_input=None): + """Handle step to setup API key.""" + if user_input is not None: + return self.async_create_entry( + title=self._config[CONF_NAME], + data={ + **self._config, + CONF_API_KEY: user_input.get(CONF_API_KEY, ""), + }, + ) + + return self.async_show_form( + step_id="api_key", + data_schema=vol.Schema({vol.Optional(CONF_API_KEY): str}), + ) + async def _async_endpoint_existed(self, endpoint): existing_endpoints = [ f"{entry.data.get(CONF_HOST)}/{entry.data.get(CONF_LOCATION)}" @@ -123,14 +155,7 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ] return endpoint in existing_endpoints - async def _async_try_connect(self, host, location, tls, verify_tls, api_token): + async def _async_try_connect(self, host, location, tls, verify_tls): session = async_get_clientsession(self.hass, verify_tls) - pi_hole = Hole( - host, - self.hass.loop, - session, - location=location, - tls=tls, - api_token=api_token, - ) + pi_hole = Hole(host, self.hass.loop, session, location=location, tls=tls) await pi_hole.get_data() diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index b15db5f3980..f1871bf27c8 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -6,12 +6,14 @@ from homeassistant.const import PERCENTAGE DOMAIN = "pi_hole" CONF_LOCATION = "location" +CONF_STATISTICS_ONLY = "statistics_only" DEFAULT_LOCATION = "admin" DEFAULT_METHOD = "GET" DEFAULT_NAME = "Pi-Hole" DEFAULT_SSL = False DEFAULT_VERIFY_SSL = True +DEFAULT_STATISTICS_ONLY = True SERVICE_DISABLE = "disable" SERVICE_DISABLE_ATTR_DURATION = "duration" diff --git a/homeassistant/components/pi_hole/strings.json b/homeassistant/components/pi_hole/strings.json index 75af03dc3a5..fbf3c5a627b 100644 --- a/homeassistant/components/pi_hole/strings.json +++ b/homeassistant/components/pi_hole/strings.json @@ -8,9 +8,15 @@ "name": "[%key:common::config_flow::data::name%]", "location": "[%key:common::config_flow::data::location%]", "api_key": "[%key:common::config_flow::data::api_key%]", + "statistics_only": "Statistics Only", "ssl": "[%key:common::config_flow::data::ssl%]", "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]" } + }, + "api_key": { + "data": { + "api_key": "[%key:common::config_flow::data::api_key%]" + } } }, "error": { diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index f2a040f615a..f02cd0c8a7a 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from hole.exceptions import HoleError -from homeassistant.components.pi_hole.const import CONF_LOCATION +from homeassistant.components.pi_hole.const import CONF_LOCATION, CONF_STATISTICS_ONLY from homeassistant.const import ( CONF_API_KEY, CONF_HOST, @@ -43,11 +43,25 @@ CONF_DATA = { CONF_VERIFY_SSL: VERIFY_SSL, } -CONF_CONFIG_FLOW = { +CONF_CONFIG_FLOW_USER = { CONF_HOST: HOST, CONF_PORT: PORT, CONF_LOCATION: LOCATION, CONF_NAME: NAME, + CONF_STATISTICS_ONLY: False, + CONF_SSL: SSL, + CONF_VERIFY_SSL: VERIFY_SSL, +} + +CONF_CONFIG_FLOW_API_KEY = { + CONF_API_KEY: API_KEY, +} + +CONF_CONFIG_ENTRY = { + CONF_HOST: f"{HOST}:{PORT}", + CONF_LOCATION: LOCATION, + CONF_NAME: NAME, + CONF_STATISTICS_ONLY: False, CONF_API_KEY: API_KEY, CONF_SSL: SSL, CONF_VERIFY_SSL: VERIFY_SSL, diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py index 28589ab0193..517697b0e8a 100644 --- a/tests/components/pi_hole/test_config_flow.py +++ b/tests/components/pi_hole/test_config_flow.py @@ -2,8 +2,9 @@ import logging from unittest.mock import patch -from homeassistant.components.pi_hole.const import DOMAIN +from homeassistant.components.pi_hole.const import CONF_STATISTICS_ONLY, DOMAIN from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER +from homeassistant.const import CONF_API_KEY from homeassistant.data_entry_flow import ( RESULT_TYPE_ABORT, RESULT_TYPE_CREATE_ENTRY, @@ -11,7 +12,9 @@ from homeassistant.data_entry_flow import ( ) from . import ( - CONF_CONFIG_FLOW, + CONF_CONFIG_ENTRY, + CONF_CONFIG_FLOW_API_KEY, + CONF_CONFIG_FLOW_USER, CONF_DATA, NAME, _create_mocked_hole, @@ -43,7 +46,7 @@ async def test_flow_import(hass, caplog): ) assert result["type"] == RESULT_TYPE_CREATE_ENTRY assert result["title"] == NAME - assert result["data"] == CONF_DATA + assert result["data"] == CONF_CONFIG_ENTRY # duplicated server result = await hass.config_entries.flow.async_init( @@ -80,28 +83,64 @@ async def test_flow_user(hass): result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=CONF_CONFIG_FLOW, + user_input=CONF_CONFIG_FLOW_USER, + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["step_id"] == "api_key" + assert result["errors"] is None + _flow_next(hass, result["flow_id"]) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=CONF_CONFIG_FLOW_API_KEY, ) assert result["type"] == RESULT_TYPE_CREATE_ENTRY assert result["title"] == NAME - assert result["data"] == CONF_DATA + assert result["data"] == CONF_CONFIG_ENTRY # duplicated server result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=CONF_CONFIG_FLOW, + data=CONF_CONFIG_FLOW_USER, ) assert result["type"] == RESULT_TYPE_ABORT assert result["reason"] == "already_configured" +async def test_flow_statistics_only(hass): + """Test user initialized flow with statistics only.""" + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_setup(): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["step_id"] == "user" + assert result["errors"] == {} + _flow_next(hass, result["flow_id"]) + + user_input = {**CONF_CONFIG_FLOW_USER} + user_input[CONF_STATISTICS_ONLY] = True + config_entry_data = {**CONF_CONFIG_ENTRY} + config_entry_data[CONF_STATISTICS_ONLY] = True + config_entry_data.pop(CONF_API_KEY) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + assert result["type"] == RESULT_TYPE_CREATE_ENTRY + assert result["title"] == NAME + assert result["data"] == config_entry_data + + async def test_flow_user_invalid(hass): """Test user initialized flow with invalid server.""" mocked_hole = _create_mocked_hole(True) with _patch_config_flow_hole(mocked_hole), _patch_setup(): result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW + DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW_USER ) assert result["type"] == RESULT_TYPE_FORM assert result["step_id"] == "user" diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index ef462270954..a14e155b3da 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -7,6 +7,7 @@ from hole.exceptions import HoleError from homeassistant.components import pi_hole, switch from homeassistant.components.pi_hole.const import ( CONF_LOCATION, + CONF_STATISTICS_ONLY, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, @@ -16,6 +17,7 @@ from homeassistant.components.pi_hole.const import ( ) from homeassistant.const import ( ATTR_ENTITY_ID, + CONF_API_KEY, CONF_HOST, CONF_NAME, CONF_SSL, @@ -24,6 +26,8 @@ from homeassistant.const import ( from homeassistant.setup import async_setup_component from . import ( + CONF_CONFIG_ENTRY, + CONF_DATA, SWITCH_ENTITY_ID, _create_mocked_hole, _patch_config_flow_hole, @@ -196,6 +200,7 @@ async def test_unload(hass): CONF_LOCATION: DEFAULT_LOCATION, CONF_SSL: DEFAULT_SSL, CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, + CONF_STATISTICS_ONLY: True, }, ) entry.add_to_hass(hass) @@ -208,3 +213,34 @@ async def test_unload(hass): assert await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() assert entry.entry_id not in hass.data[pi_hole.DOMAIN] + + +async def test_migrate(hass): + """Test migrate from old config entry.""" + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA) + entry.add_to_hass(hass) + + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + assert entry.data == CONF_CONFIG_ENTRY + + +async def test_migrate_statistics_only(hass): + """Test migrate from old config entry with statistics only.""" + conf_data = {**CONF_DATA} + conf_data[CONF_API_KEY] = "" + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=conf_data) + entry.add_to_hass(hass) + + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + config_entry_data = {**CONF_CONFIG_ENTRY} + config_entry_data[CONF_STATISTICS_ONLY] = True + config_entry_data[CONF_API_KEY] = "" + assert entry.data == config_entry_data