From 74837dbf450039854712c486bea33a4636aad0dd Mon Sep 17 00:00:00 2001 From: Johann Kellerman Date: Wed, 22 Feb 2017 22:55:11 +0200 Subject: [PATCH] Restore for device_tracker (#6150) --- .../components/device_tracker/__init__.py | 27 +++++++++++++++ tests/components/device_tracker/test_init.py | 34 +++++++++++++++++-- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/device_tracker/__init__.py b/homeassistant/components/device_tracker/__init__.py index d4a80358f02..5aa9765d983 100644 --- a/homeassistant/components/device_tracker/__init__.py +++ b/homeassistant/components/device_tracker/__init__.py @@ -25,6 +25,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.restore_state import async_get_last_state from homeassistant.helpers.typing import GPSType, ConfigType, HomeAssistantType import homeassistant.helpers.config_validation as cv import homeassistant.util as util @@ -132,6 +133,12 @@ def async_setup(hass: HomeAssistantType, config: ConfigType): devices = yield from async_load_config(yaml_path, hass, consider_home) tracker = DeviceTracker(hass, consider_home, track_new, devices) + # added_to_hass + add_tasks = [device.async_added_to_hass() for device in devices + if device.track] + if add_tasks: + yield from asyncio.wait(add_tasks, loop=hass.loop) + # update tracked devices update_tasks = [device.async_update_ha_state() for device in devices if device.track] @@ -561,6 +568,26 @@ class Device(Entity): if resp is not None: yield from resp.release() + @asyncio.coroutine + def async_added_to_hass(self): + """Called when entity about to be added to hass.""" + state = yield from async_get_last_state(self.hass, self.entity_id) + if not state: + return + self._state = state.state + + for attr, var in ( + (ATTR_SOURCE_TYPE, 'source_type'), + (ATTR_GPS_ACCURACY, 'gps_accuracy'), + (ATTR_BATTERY, 'battery'), + ): + if attr in state.attributes: + setattr(self, var, state.attributes[attr]) + + if ATTR_LONGITUDE in state.attributes: + self.gps = (state.attributes[ATTR_LATITUDE], + state.attributes[ATTR_LONGITUDE]) + class DeviceScanner(object): """Device scanner object.""" diff --git a/tests/components/device_tracker/test_init.py b/tests/components/device_tracker/test_init.py index 4f932cd177f..c12d984d275 100644 --- a/tests/components/device_tracker/test_init.py +++ b/tests/components/device_tracker/test_init.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta import os from homeassistant.components import zone -from homeassistant.core import callback +from homeassistant.core import callback, State from homeassistant.bootstrap import setup_component from homeassistant.helpers import discovery from homeassistant.loader import get_component @@ -24,7 +24,7 @@ from homeassistant.remote import JSONEncoder from tests.common import ( get_test_home_assistant, fire_time_changed, fire_service_discovered, - patch_yaml_files, assert_setup_component) + patch_yaml_files, assert_setup_component, mock_restore_cache) from ...test_util.aiohttp import mock_aiohttp_client @@ -656,3 +656,33 @@ class TestComponentsDeviceTracker(unittest.TestCase): setup_component(self.hass, device_tracker.DOMAIN, {device_tracker.DOMAIN: { device_tracker.CONF_CONSIDER_HOME: -1}}) + + +@asyncio.coroutine +def test_async_added_to_hass(hass): + """Test resoring state.""" + attr = { + device_tracker.ATTR_LONGITUDE: 18, + device_tracker.ATTR_LATITUDE: -33, + device_tracker.ATTR_LATITUDE: -33, + device_tracker.ATTR_SOURCE_TYPE: 'gps', + device_tracker.ATTR_GPS_ACCURACY: 2, + device_tracker.ATTR_BATTERY: 100 + } + mock_restore_cache(hass, [State('device_tracker.jk', 'home', attr)]) + + path = hass.config.path(device_tracker.YAML_DEVICES) + + files = { + path: 'jk:\n name: JK Phone\n track: True', + } + with patch_yaml_files(files): + yield from device_tracker.async_setup(hass, {}) + + state = hass.states.get('device_tracker.jk') + assert state + assert state.state == 'home' + + for key, val in attr.items(): + atr = state.attributes.get(key) + assert atr == val, "{}={} expected: {}".format(key, atr, val)