From cf2d927f141e195b1c3ca15ca6c5e625994e208f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 20 Aug 2019 10:46:51 -0700 Subject: [PATCH] Use init_subclass for Config Entries (#26059) * Use init_subclass for Config Entries * Ignore type --- homeassistant/components/hue/config_flow.py | 3 +- homeassistant/components/met/config_flow.py | 3 +- homeassistant/config_entries.py | 6 ++++ tests/test_config_entries.py | 35 +++++++++++---------- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/hue/config_flow.py b/homeassistant/components/hue/config_flow.py index 1d058d84b61..0b0e3723b13 100644 --- a/homeassistant/components/hue/config_flow.py +++ b/homeassistant/components/hue/config_flow.py @@ -44,8 +44,7 @@ def _find_username_from_config(hass, filename): return None -@config_entries.HANDLERS.register(DOMAIN) -class HueFlowHandler(config_entries.ConfigFlow): +class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a Hue config flow.""" VERSION = 1 diff --git a/homeassistant/components/met/config_flow.py b/homeassistant/components/met/config_flow.py index e903c717e64..795ba57d988 100644 --- a/homeassistant/components/met/config_flow.py +++ b/homeassistant/components/met/config_flow.py @@ -17,8 +17,7 @@ def configured_instances(hass): ) -@config_entries.HANDLERS.register(DOMAIN) -class MetFlowHandler(config_entries.ConfigFlow): +class MetFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Config flow for Met component.""" VERSION = 1 diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 87bce1a870c..2e1fbea14d1 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -670,6 +670,12 @@ async def _old_conf_migrator(old_config): class ConfigFlow(data_entry_flow.FlowHandler): """Base class for config flows with some helpers.""" + def __init_subclass__(cls, domain=None, **kwargs): + """Initialize a subclass, register if possible.""" + super().__init_subclass__(**kwargs) # type: ignore + if domain is not None: + HANDLERS.register(domain)(cls) + CONNECTION_CLASS = CONN_CLASS_UNKNOWN @staticmethod diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 6c1b00693dd..ca6872a7a2c 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -521,31 +521,32 @@ async def test_discovery_notification(hass): mock_entity_platform(hass, "config_flow.test", None) await async_setup_component(hass, "persistent_notification", {}) - class TestFlow(config_entries.ConfigFlow): - VERSION = 5 + with patch.dict(config_entries.HANDLERS): - async def async_step_discovery(self, user_input=None): - if user_input is not None: - return self.async_create_entry( - title="Test Title", data={"token": "abcd"} - ) - return self.async_show_form(step_id="discovery") + class TestFlow(config_entries.ConfigFlow, domain="test"): + VERSION = 5 + + async def async_step_discovery(self, user_input=None): + if user_input is not None: + return self.async_create_entry( + title="Test Title", data={"token": "abcd"} + ) + return self.async_show_form(step_id="discovery") - with patch.dict(config_entries.HANDLERS, {"test": TestFlow}): result = await hass.config_entries.flow.async_init( "test", context={"source": config_entries.SOURCE_DISCOVERY} ) - await hass.async_block_till_done() - state = hass.states.get("persistent_notification.config_entry_discovery") - assert state is not None + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is not None - result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - await hass.async_block_till_done() - state = hass.states.get("persistent_notification.config_entry_discovery") - assert state is None + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is None async def test_discovery_notification_not_created(hass):