diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 46bb2f7bfe2..b159f01c72b 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -141,6 +141,9 @@ ENTRY_STATE_SETUP_ERROR = 'setup_error' ENTRY_STATE_NOT_LOADED = 'not_loaded' ENTRY_STATE_FAILED_UNLOAD = 'failed_unload' +DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery' +DISCOVERY_SOURCES = (data_entry_flow.SOURCE_DISCOVERY,) + class ConfigEntry: """Hold a configuration entry.""" @@ -362,9 +365,19 @@ class ConfigEntries: await async_setup_component( self.hass, entry.domain, self._hass_config) + # Return Entry if they not from a discovery request + if result['source'] not in DISCOVERY_SOURCES: + return entry + + # If no discovery config entries in progress, remove notification. + if not any(ent['source'] in DISCOVERY_SOURCES for ent + in self.hass.config_entries.flow.async_progress()): + self.hass.components.persistent_notification.async_dismiss( + DISCOVERY_NOTIFICATION_ID) + return entry - async def _async_create_flow(self, handler): + async def _async_create_flow(self, handler, *, source, data): """Create a flow for specified handler. Handler key is the domain of the component that we want to setup. @@ -379,6 +392,15 @@ class ConfigEntries: await async_process_deps_reqs( self.hass, self._hass_config, handler, component) + # Create notification. + if source in DISCOVERY_SOURCES: + self.hass.components.persistent_notification.async_create( + title='New devices discovered', + message=("We have discovered new devices on your network. " + "[Check it out](/config/integrations)"), + notification_id=DISCOVERY_NOTIFICATION_ID + ) + return handler() @callback diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index cadec3f3d69..8eb18a3a7e7 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -52,7 +52,7 @@ class FlowManager: async def async_init(self, handler, *, source=SOURCE_USER, data=None): """Start a configuration flow.""" - flow = await self._async_create_flow(handler) + flow = await self._async_create_flow(handler, source=source, data=data) flow.hass = self.hass flow.handler = handler flow.flow_id = uuid.uuid4().hex diff --git a/tests/conftest.py b/tests/conftest.py index 269d460ebb6..73e69605eae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ if os.environ.get('UVLOOP') == '1': import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -logging.basicConfig() +logging.basicConfig(level=logging.INFO) logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 94b1dcb47da..b46909d7732 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -245,3 +245,39 @@ async def test_forward_entry_does_not_setup_entry_if_setup_fails(hass): await hass.config_entries.async_forward_entry_setup(entry, 'forwarded') assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 0 + + +async def test_discovery_notification(hass): + """Test that we create/dismiss a notification when source is discovery.""" + await async_setup_component(hass, 'persistent_notification', {}) + + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + async def async_step_discovery(self, user_input=None): + if user_input is not None: + return self.async_create_entry( + title='Test Title', + data={ + 'token': 'abcd' + } + ) + return self.async_show_form( + step_id='discovery', + ) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + result = await hass.config_entries.flow.async_init( + 'test', source=data_entry_flow.SOURCE_DISCOVERY) + + await hass.async_block_till_done() + state = hass.states.get('persistent_notification.config_entry_discovery') + assert state is not None + + result = await hass.config_entries.flow.async_configure( + result['flow_id'], {}) + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + + await hass.async_block_till_done() + state = hass.states.get('persistent_notification.config_entry_discovery') + assert state is None diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 2767e206c30..6d3e41436c5 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -12,7 +12,7 @@ def manager(): handlers = Registry() entries = [] - async def async_create_flow(handler_name): + async def async_create_flow(handler_name, *, source, data): handler = handlers.get(handler_name) if handler is None: