diff --git a/homeassistant/components/geo_location/geo_json_events.py b/homeassistant/components/geo_location/geo_json_events.py index 7c3f228a4c9..00ac85e6b27 100644 --- a/homeassistant/components/geo_location/geo_json_events.py +++ b/homeassistant/components/geo_location/geo_json_events.py @@ -152,19 +152,23 @@ class GeoJsonLocationEvent(GeoLocationEvent): self._distance = None self._latitude = None self._longitude = None + self._remove_signal_delete = None + self._remove_signal_update = None async def async_added_to_hass(self): """Call when entity is added to hass.""" - async_dispatcher_connect( + self._remove_signal_delete = async_dispatcher_connect( self.hass, SIGNAL_DELETE_ENTITY.format(self._external_id), self._delete_callback) - async_dispatcher_connect( + self._remove_signal_update = async_dispatcher_connect( self.hass, SIGNAL_UPDATE_ENTITY.format(self._external_id), self._update_callback) @callback def _delete_callback(self): """Remove this entity.""" + self._remove_signal_delete() + self._remove_signal_update() self.hass.async_create_task(self.async_remove()) @callback diff --git a/homeassistant/components/geo_location/nsw_rural_fire_service_feed.py b/homeassistant/components/geo_location/nsw_rural_fire_service_feed.py index d3b13abe704..79e0445f494 100644 --- a/homeassistant/components/geo_location/nsw_rural_fire_service_feed.py +++ b/homeassistant/components/geo_location/nsw_rural_fire_service_feed.py @@ -183,19 +183,23 @@ class NswRuralFireServiceLocationEvent(GeoLocationEvent): self._fire = None self._size = None self._responsible_agency = None + self._remove_signal_delete = None + self._remove_signal_update = None async def async_added_to_hass(self): """Call when entity is added to hass.""" - async_dispatcher_connect( + self._remove_signal_delete = async_dispatcher_connect( self.hass, SIGNAL_DELETE_ENTITY.format(self._external_id), self._delete_callback) - async_dispatcher_connect( + self._remove_signal_update = async_dispatcher_connect( self.hass, SIGNAL_UPDATE_ENTITY.format(self._external_id), self._update_callback) @callback def _delete_callback(self): """Remove this entity.""" + self._remove_signal_delete() + self._remove_signal_update() self.hass.async_create_task(self.async_remove()) @callback diff --git a/tests/components/geo_location/test_geo_json_events.py b/tests/components/geo_location/test_geo_json_events.py index dbaf71a6509..00fc9f8c996 100644 --- a/tests/components/geo_location/test_geo_json_events.py +++ b/tests/components/geo_location/test_geo_json_events.py @@ -3,6 +3,7 @@ import unittest from unittest import mock from unittest.mock import patch, MagicMock +import homeassistant from homeassistant.components import geo_location from homeassistant.components.geo_location import ATTR_SOURCE from homeassistant.components.geo_location.geo_json_events import \ @@ -138,3 +139,90 @@ class TestGeoJsonPlatform(unittest.TestCase): all_states = self.hass.states.all() assert len(all_states) == 0 + + @mock.patch('geojson_client.generic_feed.GenericFeed') + def test_setup_race_condition(self, mock_feed): + """Test a particular race condition experienced.""" + # 1. Feed returns 1 entry -> Feed manager creates 1 entity. + # 2. Feed returns error -> Feed manager removes 1 entity. + # However, this stayed on and kept listening for dispatcher signals. + # 3. Feed returns 1 entry -> Feed manager creates 1 entity. + # 4. Feed returns 1 entry -> Feed manager updates 1 entity. + # Internally, the previous entity is updating itself, too. + # 5. Feed returns error -> Feed manager removes 1 entity. + # There are now 2 entities trying to remove themselves from HA, but + # the second attempt fails of course. + + # Set up some mock feed entries for this test. + mock_entry_1 = self._generate_mock_feed_entry('1234', 'Title 1', 15.5, + (-31.0, 150.0)) + mock_feed.return_value.update.return_value = 'OK', [mock_entry_1] + + utcnow = dt_util.utcnow() + # Patching 'utcnow' to gain more control over the timed update. + with patch('homeassistant.util.dt.utcnow', return_value=utcnow): + with assert_setup_component(1, geo_location.DOMAIN): + self.assertTrue(setup_component(self.hass, geo_location.DOMAIN, + CONFIG)) + + # This gives us the ability to assert the '_delete_callback' + # has been called while still executing it. + original_delete_callback = homeassistant.components\ + .geo_location.geo_json_events.GeoJsonLocationEvent\ + ._delete_callback + + def mock_delete_callback(entity): + original_delete_callback(entity) + + with patch('homeassistant.components.geo_location' + '.geo_json_events.GeoJsonLocationEvent' + '._delete_callback', + side_effect=mock_delete_callback, + autospec=True) as mocked_delete_callback: + + # Artificially trigger update. + self.hass.bus.fire(EVENT_HOMEASSISTANT_START) + # Collect events. + self.hass.block_till_done() + + all_states = self.hass.states.all() + assert len(all_states) == 1 + + # Simulate an update - empty data, removes all entities + mock_feed.return_value.update.return_value = 'ERROR', None + fire_time_changed(self.hass, utcnow + SCAN_INTERVAL) + self.hass.block_till_done() + + assert mocked_delete_callback.call_count == 1 + all_states = self.hass.states.all() + assert len(all_states) == 0 + + # Simulate an update - 1 entry + mock_feed.return_value.update.return_value = 'OK', [ + mock_entry_1] + fire_time_changed(self.hass, utcnow + 2 * SCAN_INTERVAL) + self.hass.block_till_done() + + all_states = self.hass.states.all() + assert len(all_states) == 1 + + # Simulate an update - 1 entry + mock_feed.return_value.update.return_value = 'OK', [ + mock_entry_1] + fire_time_changed(self.hass, utcnow + 3 * SCAN_INTERVAL) + self.hass.block_till_done() + + all_states = self.hass.states.all() + assert len(all_states) == 1 + + # Reset mocked method for the next test. + mocked_delete_callback.reset_mock() + + # Simulate an update - empty data, removes all entities + mock_feed.return_value.update.return_value = 'ERROR', None + fire_time_changed(self.hass, utcnow + 4 * SCAN_INTERVAL) + self.hass.block_till_done() + + assert mocked_delete_callback.call_count == 1 + all_states = self.hass.states.all() + assert len(all_states) == 0