Add tls support for AVM Fritz!Tools (#112714)

This commit is contained in:
r-binder 2024-04-20 23:08:29 +02:00 committed by GitHub
parent b450918f66
commit 68225abce5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 210 additions and 51 deletions

View file

@ -3,13 +3,20 @@
import logging import logging
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
CONF_PORT,
CONF_SSL,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from .common import AvmWrapper, FritzData from .common import AvmWrapper, FritzData
from .const import ( from .const import (
DATA_FRITZ, DATA_FRITZ,
DEFAULT_SSL,
DOMAIN, DOMAIN,
FRITZ_AUTH_EXCEPTIONS, FRITZ_AUTH_EXCEPTIONS,
FRITZ_EXCEPTIONS, FRITZ_EXCEPTIONS,
@ -29,6 +36,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
port=entry.data[CONF_PORT], port=entry.data[CONF_PORT],
username=entry.data[CONF_USERNAME], username=entry.data[CONF_USERNAME],
password=entry.data[CONF_PASSWORD], password=entry.data[CONF_PASSWORD],
use_tls=entry.data.get(CONF_SSL, DEFAULT_SSL),
) )
try: try:

View file

@ -48,7 +48,7 @@ from .const import (
DEFAULT_CONF_OLD_DISCOVERY, DEFAULT_CONF_OLD_DISCOVERY,
DEFAULT_DEVICE_NAME, DEFAULT_DEVICE_NAME,
DEFAULT_HOST, DEFAULT_HOST,
DEFAULT_PORT, DEFAULT_SSL,
DEFAULT_USERNAME, DEFAULT_USERNAME,
DOMAIN, DOMAIN,
FRITZ_EXCEPTIONS, FRITZ_EXCEPTIONS,
@ -184,9 +184,10 @@ class FritzBoxTools(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
password: str, password: str,
port: int,
username: str = DEFAULT_USERNAME, username: str = DEFAULT_USERNAME,
host: str = DEFAULT_HOST, host: str = DEFAULT_HOST,
port: int = DEFAULT_PORT, use_tls: bool = DEFAULT_SSL,
) -> None: ) -> None:
"""Initialize FritzboxTools class.""" """Initialize FritzboxTools class."""
super().__init__( super().__init__(
@ -211,6 +212,7 @@ class FritzBoxTools(
self.password = password self.password = password
self.port = port self.port = port
self.username = username self.username = username
self.use_tls = use_tls
self.has_call_deflections: bool = False self.has_call_deflections: bool = False
self._model: str | None = None self._model: str | None = None
self._current_firmware: str | None = None self._current_firmware: str | None = None
@ -230,11 +232,13 @@ class FritzBoxTools(
def setup(self) -> None: def setup(self) -> None:
"""Set up FritzboxTools class.""" """Set up FritzboxTools class."""
self.connection = FritzConnection( self.connection = FritzConnection(
address=self.host, address=self.host,
port=self.port, port=self.port,
user=self.username, user=self.username,
password=self.password, password=self.password,
use_tls=self.use_tls,
timeout=60.0, timeout=60.0,
pool_maxsize=30, pool_maxsize=30,
) )

View file

@ -25,14 +25,22 @@ from homeassistant.config_entries import (
OptionsFlow, OptionsFlow,
OptionsFlowWithConfigEntry, OptionsFlowWithConfigEntry,
) )
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
CONF_PORT,
CONF_SSL,
CONF_USERNAME,
)
from homeassistant.core import callback from homeassistant.core import callback
from .const import ( from .const import (
CONF_OLD_DISCOVERY, CONF_OLD_DISCOVERY,
DEFAULT_CONF_OLD_DISCOVERY, DEFAULT_CONF_OLD_DISCOVERY,
DEFAULT_HOST, DEFAULT_HOST,
DEFAULT_PORT, DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_SSL,
DOMAIN, DOMAIN,
ERROR_AUTH_INVALID, ERROR_AUTH_INVALID,
ERROR_CANNOT_CONNECT, ERROR_CANNOT_CONNECT,
@ -61,6 +69,7 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
self._entry: ConfigEntry | None = None self._entry: ConfigEntry | None = None
self._name: str = "" self._name: str = ""
self._password: str = "" self._password: str = ""
self._use_tls: bool = False
self._port: int | None = None self._port: int | None = None
self._username: str = "" self._username: str = ""
self._model: str = "" self._model: str = ""
@ -74,6 +83,7 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
port=self._port, port=self._port,
user=self._username, user=self._username,
password=self._password, password=self._password,
use_tls=self._use_tls,
timeout=60.0, timeout=60.0,
pool_maxsize=30, pool_maxsize=30,
) )
@ -120,6 +130,7 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
CONF_PASSWORD: self._password, CONF_PASSWORD: self._password,
CONF_PORT: self._port, CONF_PORT: self._port,
CONF_USERNAME: self._username, CONF_USERNAME: self._username,
CONF_SSL: self._use_tls,
}, },
options={ options={
CONF_CONSIDER_HOME: DEFAULT_CONSIDER_HOME.total_seconds(), CONF_CONSIDER_HOME: DEFAULT_CONSIDER_HOME.total_seconds(),
@ -133,7 +144,6 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
ssdp_location: ParseResult = urlparse(discovery_info.ssdp_location or "") ssdp_location: ParseResult = urlparse(discovery_info.ssdp_location or "")
self._host = ssdp_location.hostname self._host = ssdp_location.hostname
self._port = ssdp_location.port
self._name = ( self._name = (
discovery_info.upnp.get(ssdp.ATTR_UPNP_FRIENDLY_NAME) discovery_info.upnp.get(ssdp.ATTR_UPNP_FRIENDLY_NAME)
or discovery_info.upnp[ssdp.ATTR_UPNP_MODEL_NAME] or discovery_info.upnp[ssdp.ATTR_UPNP_MODEL_NAME]
@ -178,6 +188,8 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
self._username = user_input[CONF_USERNAME] self._username = user_input[CONF_USERNAME]
self._password = user_input[CONF_PASSWORD] self._password = user_input[CONF_PASSWORD]
self._use_tls = user_input[CONF_SSL]
self._port = DEFAULT_HTTPS_PORT if self._use_tls else DEFAULT_HTTP_PORT
error = await self.hass.async_add_executor_job(self.fritz_tools_init) error = await self.hass.async_add_executor_job(self.fritz_tools_init)
@ -191,14 +203,22 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
self, errors: dict[str, str] | None = None self, errors: dict[str, str] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Show the setup form to the user.""" """Show the setup form to the user."""
advanced_data_schema = {}
if self.show_advanced_options:
advanced_data_schema = {
vol.Optional(CONF_PORT): vol.Coerce(int),
}
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
data_schema=vol.Schema( data_schema=vol.Schema(
{ {
vol.Optional(CONF_HOST, default=DEFAULT_HOST): str, vol.Optional(CONF_HOST, default=DEFAULT_HOST): str,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): vol.Coerce(int), **advanced_data_schema,
vol.Required(CONF_USERNAME): str, vol.Required(CONF_USERNAME): str,
vol.Required(CONF_PASSWORD): str, vol.Required(CONF_PASSWORD): str,
vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool,
} }
), ),
errors=errors or {}, errors=errors or {},
@ -214,6 +234,7 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
{ {
vol.Required(CONF_USERNAME): str, vol.Required(CONF_USERNAME): str,
vol.Required(CONF_PASSWORD): str, vol.Required(CONF_PASSWORD): str,
vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool,
} }
), ),
description_placeholders={"name": self._name}, description_placeholders={"name": self._name},
@ -227,9 +248,14 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
if user_input is None: if user_input is None:
return self._show_setup_form_init() return self._show_setup_form_init()
self._host = user_input[CONF_HOST] self._host = user_input[CONF_HOST]
self._port = user_input[CONF_PORT]
self._username = user_input[CONF_USERNAME] self._username = user_input[CONF_USERNAME]
self._password = user_input[CONF_PASSWORD] self._password = user_input[CONF_PASSWORD]
self._use_tls = user_input[CONF_SSL]
if (port := user_input.get(CONF_PORT)) is None:
self._port = DEFAULT_HTTPS_PORT if self._use_tls else DEFAULT_HTTP_PORT
else:
self._port = port
if not (error := await self.hass.async_add_executor_job(self.fritz_tools_init)): if not (error := await self.hass.async_add_executor_job(self.fritz_tools_init)):
self._name = self._model self._name = self._model
@ -251,6 +277,8 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
self._port = entry_data[CONF_PORT] self._port = entry_data[CONF_PORT]
self._username = entry_data[CONF_USERNAME] self._username = entry_data[CONF_USERNAME]
self._password = entry_data[CONF_PASSWORD] self._password = entry_data[CONF_PASSWORD]
self._use_tls = entry_data[CONF_SSL]
return await self.async_step_reauth_confirm() return await self.async_step_reauth_confirm()
def _show_setup_form_reauth_confirm( def _show_setup_form_reauth_confirm(
@ -295,6 +323,7 @@ class FritzBoxToolsFlowHandler(ConfigFlow, domain=DOMAIN):
CONF_PASSWORD: self._password, CONF_PASSWORD: self._password,
CONF_PORT: self._port, CONF_PORT: self._port,
CONF_USERNAME: self._username, CONF_USERNAME: self._username,
CONF_SSL: self._use_tls,
}, },
) )
await self.hass.config_entries.async_reload(self._entry.entry_id) await self.hass.config_entries.async_reload(self._entry.entry_id)

View file

@ -46,8 +46,10 @@ DSL_CONNECTION: Literal["dsl"] = "dsl"
DEFAULT_DEVICE_NAME = "Unknown device" DEFAULT_DEVICE_NAME = "Unknown device"
DEFAULT_HOST = "192.168.178.1" DEFAULT_HOST = "192.168.178.1"
DEFAULT_PORT = 49000 DEFAULT_HTTP_PORT = 49000
DEFAULT_HTTPS_PORT = 49443
DEFAULT_USERNAME = "" DEFAULT_USERNAME = ""
DEFAULT_SSL = False
ERROR_AUTH_INVALID = "invalid_auth" ERROR_AUTH_INVALID = "invalid_auth"
ERROR_CANNOT_CONNECT = "cannot_connect" ERROR_CANNOT_CONNECT = "cannot_connect"

View file

@ -25,10 +25,12 @@
"host": "[%key:common::config_flow::data::host%]", "host": "[%key:common::config_flow::data::host%]",
"port": "[%key:common::config_flow::data::port%]", "port": "[%key:common::config_flow::data::port%]",
"username": "[%key:common::config_flow::data::username%]", "username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]" "password": "[%key:common::config_flow::data::password%]",
"ssl": "[%key:common::config_flow::data::ssl%]"
}, },
"data_description": { "data_description": {
"host": "The hostname or IP address of your FRITZ!Box router." "host": "The hostname or IP address of your FRITZ!Box router.",
"port": "Leave it empty to use the default port."
} }
} }
}, },

View file

@ -74,16 +74,6 @@ class FritzConnectionMock:
return self._services[service][action] return self._services[service][action]
class FritzHostMock(FritzHosts):
"""FritzHosts mocking."""
get_mesh_topology = MagicMock()
get_mesh_topology.return_value = MOCK_MESH_DATA
get_hosts_attributes = MagicMock()
get_hosts_attributes.return_value = MOCK_HOST_ATTRIBUTES_DATA
@pytest.fixture(name="fc_data") @pytest.fixture(name="fc_data")
def fc_data_mock(): def fc_data_mock():
"""Fixture for default fc_data.""" """Fixture for default fc_data."""
@ -105,6 +95,8 @@ def fh_class_mock():
"""Fixture that sets up a mocked FritzHosts class.""" """Fixture that sets up a mocked FritzHosts class."""
with patch( with patch(
"homeassistant.components.fritz.common.FritzHosts", "homeassistant.components.fritz.common.FritzHosts",
new=FritzHostMock, new=FritzHosts,
) as result: ) as result:
result.get_mesh_topology = MagicMock(return_value=MOCK_MESH_DATA)
result.get_hosts_attributes = MagicMock(return_value=MOCK_HOST_ATTRIBUTES_DATA)
yield result yield result

View file

@ -8,6 +8,7 @@ from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PORT, CONF_PORT,
CONF_SSL,
CONF_USERNAME, CONF_USERNAME,
) )
@ -22,10 +23,12 @@ MOCK_CONFIG = {
CONF_PORT: "1234", CONF_PORT: "1234",
CONF_PASSWORD: "fake_pass", CONF_PASSWORD: "fake_pass",
CONF_USERNAME: "fake_user", CONF_USERNAME: "fake_user",
CONF_SSL: False,
} }
] ]
} }
} }
MOCK_HOST = "fake_host" MOCK_HOST = "fake_host"
MOCK_IPS = { MOCK_IPS = {
"fritz.box": "192.168.178.1", "fritz.box": "192.168.178.1",
@ -902,6 +905,14 @@ MOCK_HOST_ATTRIBUTES_DATA = [
] ]
MOCK_USER_DATA = MOCK_CONFIG[DOMAIN][CONF_DEVICES][0] MOCK_USER_DATA = MOCK_CONFIG[DOMAIN][CONF_DEVICES][0]
MOCK_USER_INPUT_ADVANCED = MOCK_USER_DATA
MOCK_USER_INPUT_SIMPLE = {
CONF_HOST: "fake_host",
CONF_PASSWORD: "fake_pass",
CONF_USERNAME: "fake_user",
CONF_SSL: False,
}
MOCK_DEVICE_INFO = { MOCK_DEVICE_INFO = {
ATTR_HOST: MOCK_HOST, ATTR_HOST: MOCK_HOST,
ATTR_NEW_SERIAL_NUMBER: MOCK_SERIAL_NUMBER, ATTR_NEW_SERIAL_NUMBER: MOCK_SERIAL_NUMBER,

View file

@ -24,7 +24,13 @@ from homeassistant.components.fritz.const import (
) )
from homeassistant.components.ssdp import ATTR_UPNP_UDN from homeassistant.components.ssdp import ATTR_UPNP_UDN
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_SSDP, SOURCE_USER from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_SSDP, SOURCE_USER
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
CONF_PORT,
CONF_SSL,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@ -34,12 +40,59 @@ from .const import (
MOCK_REQUEST, MOCK_REQUEST,
MOCK_SSDP_DATA, MOCK_SSDP_DATA,
MOCK_USER_DATA, MOCK_USER_DATA,
MOCK_USER_INPUT_ADVANCED,
MOCK_USER_INPUT_SIMPLE,
) )
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_user(hass: HomeAssistant, fc_class_mock, mock_get_source_ip) -> None: @pytest.mark.parametrize(
("show_advanced_options", "user_input", "expected_config"),
[
(
True,
MOCK_USER_INPUT_ADVANCED,
{
CONF_HOST: "fake_host",
CONF_PASSWORD: "fake_pass",
CONF_USERNAME: "fake_user",
CONF_PORT: 1234,
CONF_SSL: False,
},
),
(
False,
MOCK_USER_INPUT_SIMPLE,
{
CONF_HOST: "fake_host",
CONF_PASSWORD: "fake_pass",
CONF_USERNAME: "fake_user",
CONF_PORT: 49000,
CONF_SSL: False,
},
),
(
False,
{**MOCK_USER_INPUT_SIMPLE, CONF_SSL: True},
{
CONF_HOST: "fake_host",
CONF_PASSWORD: "fake_pass",
CONF_USERNAME: "fake_user",
CONF_PORT: 49443,
CONF_SSL: True,
},
),
],
)
async def test_user(
hass: HomeAssistant,
fc_class_mock,
mock_get_source_ip,
show_advanced_options: bool,
user_input: dict,
expected_config: dict,
) -> None:
"""Test starting a flow by user.""" """Test starting a flow by user."""
with ( with (
patch( patch(
@ -68,18 +121,20 @@ async def test_user(hass: HomeAssistant, fc_class_mock, mock_get_source_ip) -> N
mock_request_post.return_value.text = MOCK_REQUEST mock_request_post.return_value.text = MOCK_REQUEST
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={
"source": SOURCE_USER,
"show_advanced_options": show_advanced_options,
},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_DATA result["flow_id"], user_input=user_input
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"][CONF_HOST] == "fake_host" assert result["data"] == expected_config
assert result["data"][CONF_PASSWORD] == "fake_pass"
assert result["data"][CONF_USERNAME] == "fake_user"
assert ( assert (
result["options"][CONF_CONSIDER_HOME] result["options"][CONF_CONSIDER_HOME]
== DEFAULT_CONSIDER_HOME.total_seconds() == DEFAULT_CONSIDER_HOME.total_seconds()
@ -90,12 +145,20 @@ async def test_user(hass: HomeAssistant, fc_class_mock, mock_get_source_ip) -> N
assert mock_setup_entry.called assert mock_setup_entry.called
@pytest.mark.parametrize(
("show_advanced_options", "user_input"),
[(True, MOCK_USER_INPUT_ADVANCED), (False, MOCK_USER_INPUT_SIMPLE)],
)
async def test_user_already_configured( async def test_user_already_configured(
hass: HomeAssistant, fc_class_mock, mock_get_source_ip hass: HomeAssistant,
fc_class_mock,
mock_get_source_ip,
show_advanced_options: bool,
user_input,
) -> None: ) -> None:
"""Test starting a flow by user with an already configured device.""" """Test starting a flow by user with an already configured device."""
mock_config = MockConfigEntry(domain=DOMAIN, data=MOCK_USER_DATA) mock_config = MockConfigEntry(domain=DOMAIN, data=user_input)
mock_config.add_to_hass(hass) mock_config.add_to_hass(hass)
with ( with (
@ -124,13 +187,17 @@ async def test_user_already_configured(
mock_request_post.return_value.text = MOCK_REQUEST mock_request_post.return_value.text = MOCK_REQUEST
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={
"source": SOURCE_USER,
"show_advanced_options": show_advanced_options,
},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_DATA result["flow_id"], user_input=MOCK_USER_INPUT_SIMPLE
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
@ -141,13 +208,22 @@ async def test_user_already_configured(
"error", "error",
FRITZ_AUTH_EXCEPTIONS, FRITZ_AUTH_EXCEPTIONS,
) )
@pytest.mark.parametrize(
("show_advanced_options", "user_input"),
[(True, MOCK_USER_INPUT_ADVANCED), (False, MOCK_USER_INPUT_SIMPLE)],
)
async def test_exception_security( async def test_exception_security(
hass: HomeAssistant, mock_get_source_ip, error hass: HomeAssistant,
mock_get_source_ip,
error,
show_advanced_options: bool,
user_input,
) -> None: ) -> None:
"""Test starting a flow by user with invalid credentials.""" """Test starting a flow by user with invalid credentials."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={"source": SOURCE_USER, "show_advanced_options": show_advanced_options},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
@ -157,7 +233,7 @@ async def test_exception_security(
side_effect=error, side_effect=error,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_DATA result["flow_id"], user_input=user_input
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
@ -165,11 +241,21 @@ async def test_exception_security(
assert result["errors"]["base"] == ERROR_AUTH_INVALID assert result["errors"]["base"] == ERROR_AUTH_INVALID
async def test_exception_connection(hass: HomeAssistant, mock_get_source_ip) -> None: @pytest.mark.parametrize(
("show_advanced_options", "user_input"),
[(True, MOCK_USER_INPUT_ADVANCED), (False, MOCK_USER_INPUT_SIMPLE)],
)
async def test_exception_connection(
hass: HomeAssistant,
mock_get_source_ip,
show_advanced_options: bool,
user_input,
) -> None:
"""Test starting a flow by user with a connection error.""" """Test starting a flow by user with a connection error."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={"source": SOURCE_USER, "show_advanced_options": show_advanced_options},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
@ -179,7 +265,7 @@ async def test_exception_connection(hass: HomeAssistant, mock_get_source_ip) ->
side_effect=FritzConnectionException, side_effect=FritzConnectionException,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_DATA result["flow_id"], user_input=user_input
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
@ -187,11 +273,18 @@ async def test_exception_connection(hass: HomeAssistant, mock_get_source_ip) ->
assert result["errors"]["base"] == ERROR_CANNOT_CONNECT assert result["errors"]["base"] == ERROR_CANNOT_CONNECT
async def test_exception_unknown(hass: HomeAssistant, mock_get_source_ip) -> None: @pytest.mark.parametrize(
("show_advanced_options", "user_input"),
[(True, MOCK_USER_INPUT_ADVANCED), (False, MOCK_USER_INPUT_SIMPLE)],
)
async def test_exception_unknown(
hass: HomeAssistant, mock_get_source_ip, show_advanced_options: bool, user_input
) -> None:
"""Test starting a flow by user with an unknown exception.""" """Test starting a flow by user with an unknown exception."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={"source": SOURCE_USER, "show_advanced_options": show_advanced_options},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
@ -201,7 +294,7 @@ async def test_exception_unknown(hass: HomeAssistant, mock_get_source_ip) -> Non
side_effect=OSError, side_effect=OSError,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_DATA result["flow_id"], user_input=user_input
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
@ -210,7 +303,9 @@ async def test_exception_unknown(hass: HomeAssistant, mock_get_source_ip) -> Non
async def test_reauth_successful( async def test_reauth_successful(
hass: HomeAssistant, fc_class_mock, mock_get_source_ip hass: HomeAssistant,
fc_class_mock,
mock_get_source_ip,
) -> None: ) -> None:
"""Test starting a reauthentication flow.""" """Test starting a reauthentication flow."""
@ -273,7 +368,11 @@ async def test_reauth_successful(
], ],
) )
async def test_reauth_not_successful( async def test_reauth_not_successful(
hass: HomeAssistant, fc_class_mock, mock_get_source_ip, side_effect, error hass: HomeAssistant,
fc_class_mock,
mock_get_source_ip,
side_effect,
error,
) -> None: ) -> None:
"""Test starting a reauthentication flow but no connection found.""" """Test starting a reauthentication flow but no connection found."""

View file

@ -15,6 +15,8 @@ from tests.common import MockConfigEntry
MOCK_WLANCONFIGS_SAME_SSID: dict[str, dict] = { MOCK_WLANCONFIGS_SAME_SSID: dict[str, dict] = {
"WLANConfiguration1": { "WLANConfiguration1": {
"GetSSID": {"NewSSID": "WiFi"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -34,9 +36,11 @@ MOCK_WLANCONFIGS_SAME_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
"WLANConfiguration2": { "WLANConfiguration2": {
"GetSSID": {"NewSSID": "WiFi"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -56,11 +60,13 @@ MOCK_WLANCONFIGS_SAME_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
} }
MOCK_WLANCONFIGS_DIFF_SSID: dict[str, dict] = { MOCK_WLANCONFIGS_DIFF_SSID: dict[str, dict] = {
"WLANConfiguration1": { "WLANConfiguration1": {
"GetSSID": {"NewSSID": "WiFi"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -80,9 +86,11 @@ MOCK_WLANCONFIGS_DIFF_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
"WLANConfiguration2": { "WLANConfiguration2": {
"GetSSID": {"NewSSID": "WiFi2"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -102,11 +110,13 @@ MOCK_WLANCONFIGS_DIFF_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
} }
MOCK_WLANCONFIGS_DIFF2_SSID: dict[str, dict] = { MOCK_WLANCONFIGS_DIFF2_SSID: dict[str, dict] = {
"WLANConfiguration1": { "WLANConfiguration1": {
"GetSSID": {"NewSSID": "WiFi"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -126,9 +136,11 @@ MOCK_WLANCONFIGS_DIFF2_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
"WLANConfiguration2": { "WLANConfiguration2": {
"GetSSID": {"NewSSID": "WiFi+"},
"GetSecurityKeys": {"NewKeyPassphrase": "mysecret"},
"GetInfo": { "GetInfo": {
"NewEnable": True, "NewEnable": True,
"NewStatus": "Up", "NewStatus": "Up",
@ -148,7 +160,7 @@ MOCK_WLANCONFIGS_DIFF2_SSID: dict[str, dict] = {
"NewMinCharsPSK": 64, "NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64, "NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef", "NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
} },
}, },
} }
@ -179,7 +191,7 @@ async def test_switch_setup(
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done(wait_background_tasks=True)
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
switches = hass.states.async_all(Platform.SWITCH) switches = hass.states.async_all(Platform.SWITCH)