Async migration device_tracker (#4406)

* Async migration device_tracker

* change location stuff to async

* address paulus comments

* fix lint & add async discovery listener

* address paulus comments v2

* fix tests

* fix test_mqtt

* fix test_init

* fix gps_acc

* fix lint

* change async_update_stale to callback
This commit is contained in:
Pascal Vizeli 2016-11-18 23:35:08 +01:00 committed by GitHub
parent 265232af98
commit c56f99baaf
7 changed files with 233 additions and 120 deletions

View file

@ -8,13 +8,13 @@ import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
import os import os
import threading
from typing import Any, Sequence, Callable from typing import Any, Sequence, Callable
import voluptuous as vol import voluptuous as vol
from homeassistant.bootstrap import ( from homeassistant.bootstrap import (
prepare_setup_platform, log_exception) async_prepare_setup_platform, async_log_exception)
from homeassistant.core import callback
from homeassistant.components import group, zone from homeassistant.components import group, zone
from homeassistant.components.discovery import SERVICE_NETGEAR from homeassistant.components.discovery import SERVICE_NETGEAR
from homeassistant.config import load_yaml_config_file from homeassistant.config import load_yaml_config_file
@ -28,7 +28,7 @@ from homeassistant.util.async import run_coroutine_threadsafe
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.yaml import dump from homeassistant.util.yaml import dump
from homeassistant.helpers.event import track_utc_time_change from homeassistant.helpers.event import async_track_utc_time_change
from homeassistant.const import ( from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
DEVICE_DEFAULT_NAME, STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_ID) DEVICE_DEFAULT_NAME, STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_ID)
@ -106,14 +106,15 @@ def see(hass: HomeAssistantType, mac: str=None, dev_id: str=None,
hass.services.call(DOMAIN, SERVICE_SEE, data) hass.services.call(DOMAIN, SERVICE_SEE, data)
def setup(hass: HomeAssistantType, config: ConfigType): @asyncio.coroutine
def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Setup device tracker.""" """Setup device tracker."""
yaml_path = hass.config.path(YAML_DEVICES) yaml_path = hass.config.path(YAML_DEVICES)
try: try:
conf = config.get(DOMAIN, []) conf = config.get(DOMAIN, [])
except vol.Invalid as ex: except vol.Invalid as ex:
log_exception(ex, DOMAIN, config, hass) async_log_exception(ex, DOMAIN, config, hass)
return False return False
else: else:
conf = conf[0] if len(conf) > 0 else {} conf = conf[0] if len(conf) > 0 else {}
@ -121,60 +122,77 @@ def setup(hass: HomeAssistantType, config: ConfigType):
timedelta(seconds=DEFAULT_CONSIDER_HOME)) timedelta(seconds=DEFAULT_CONSIDER_HOME))
track_new = conf.get(CONF_TRACK_NEW, DEFAULT_TRACK_NEW) track_new = conf.get(CONF_TRACK_NEW, DEFAULT_TRACK_NEW)
devices = load_config(yaml_path, hass, consider_home) devices = yield from async_load_config(yaml_path, hass, consider_home)
tracker = DeviceTracker(hass, consider_home, track_new, devices) tracker = DeviceTracker(hass, consider_home, track_new, devices)
def setup_platform(p_type, p_config, disc_info=None): # update tracked devices
update_tasks = [device.async_update_ha_state() for device in devices
if device.track]
if update_tasks:
yield from asyncio.wait(update_tasks, loop=hass.loop)
@asyncio.coroutine
def async_setup_platform(p_type, p_config, disc_info=None):
"""Setup a device tracker platform.""" """Setup a device tracker platform."""
platform = prepare_setup_platform(hass, config, DOMAIN, p_type) platform = yield from async_prepare_setup_platform(
hass, config, DOMAIN, p_type)
if platform is None: if platform is None:
return return
try: try:
if hasattr(platform, 'get_scanner'): if hasattr(platform, 'get_scanner'):
scanner = platform.get_scanner(hass, {DOMAIN: p_config}) scanner = yield from hass.loop.run_in_executor(
None, platform.get_scanner, hass, {DOMAIN: p_config})
if scanner is None: if scanner is None:
_LOGGER.error('Error setting up platform %s', p_type) _LOGGER.error('Error setting up platform %s', p_type)
return return
setup_scanner_platform(hass, p_config, scanner, tracker.see) yield from async_setup_scanner_platform(
hass, p_config, scanner, tracker.async_see)
return return
if not platform.setup_scanner(hass, p_config, tracker.see): ret = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
if not ret:
_LOGGER.error('Error setting up platform %s', p_type) _LOGGER.error('Error setting up platform %s', p_type)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error setting up platform %s', p_type) _LOGGER.exception('Error setting up platform %s', p_type)
for p_type, p_config in config_per_platform(config, DOMAIN): setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
setup_platform(p_type, p_config) in config_per_platform(config, DOMAIN)]
if setup_tasks:
yield from asyncio.wait(setup_tasks, loop=hass.loop)
def device_tracker_discovered(service, info): yield from tracker.async_setup_group()
@callback
def async_device_tracker_discovered(service, info):
"""Called when a device tracker platform is discovered.""" """Called when a device tracker platform is discovered."""
setup_platform(DISCOVERY_PLATFORMS[service], {}, info) hass.async_add_job(
async_setup_platform(DISCOVERY_PLATFORMS[service], {}, info))
discovery.listen(hass, DISCOVERY_PLATFORMS.keys(), discovery.async_listen(
device_tracker_discovered) hass, DISCOVERY_PLATFORMS.keys(), async_device_tracker_discovered)
def update_stale(now): # Clean up stale devices
"""Clean up stale devices.""" async_track_utc_time_change(
tracker.update_stale(now) hass, tracker.async_update_stale, second=range(0, 60, 5))
track_utc_time_change(hass, update_stale, second=range(0, 60, 5))
tracker.setup_group() @asyncio.coroutine
def async_see_service(call):
def see_service(call):
"""Service to see a device.""" """Service to see a device."""
args = {key: value for key, value in call.data.items() if key in args = {key: value for key, value in call.data.items() if key in
(ATTR_MAC, ATTR_DEV_ID, ATTR_HOST_NAME, ATTR_LOCATION_NAME, (ATTR_MAC, ATTR_DEV_ID, ATTR_HOST_NAME, ATTR_LOCATION_NAME,
ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)} ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)}
tracker.see(**args) yield from tracker.async_see(**args)
descriptions = load_yaml_config_file( descriptions = yield from hass.loop.run_in_executor(
os.path.join(os.path.dirname(__file__), 'services.yaml')) None, load_yaml_config_file,
hass.services.register(DOMAIN, SERVICE_SEE, see_service, os.path.join(os.path.dirname(__file__), 'services.yaml')
descriptions.get(SERVICE_SEE)) )
hass.services.async_register(
DOMAIN, SERVICE_SEE, async_see_service, descriptions.get(SERVICE_SEE))
return True return True
@ -188,27 +206,35 @@ class DeviceTracker(object):
self.hass = hass self.hass = hass
self.devices = {dev.dev_id: dev for dev in devices} self.devices = {dev.dev_id: dev for dev in devices}
self.mac_to_dev = {dev.mac: dev for dev in devices if dev.mac} self.mac_to_dev = {dev.mac: dev for dev in devices if dev.mac}
self.consider_home = consider_home
self.track_new = track_new
self.group = None # type: group.Group
self._is_updating = asyncio.Lock(loop=hass.loop)
for dev in devices: for dev in devices:
if self.devices[dev.dev_id] is not dev: if self.devices[dev.dev_id] is not dev:
_LOGGER.warning('Duplicate device IDs detected %s', dev.dev_id) _LOGGER.warning('Duplicate device IDs detected %s', dev.dev_id)
if dev.mac and self.mac_to_dev[dev.mac] is not dev: if dev.mac and self.mac_to_dev[dev.mac] is not dev:
_LOGGER.warning('Duplicate device MAC addresses detected %s', _LOGGER.warning('Duplicate device MAC addresses detected %s',
dev.mac) dev.mac)
self.consider_home = consider_home
self.track_new = track_new
self.lock = threading.Lock()
for device in devices:
if device.track:
device.update_ha_state()
self.group = None # type: group.Group
def see(self, mac: str=None, dev_id: str=None, host_name: str=None, def see(self, mac: str=None, dev_id: str=None, host_name: str=None,
location_name: str=None, gps: GPSType=None, gps_accuracy=None, location_name: str=None, gps: GPSType=None, gps_accuracy=None,
battery: str=None, attributes: dict=None): battery: str=None, attributes: dict=None):
"""Notify the device tracker that you see a device.""" """Notify the device tracker that you see a device."""
with self.lock: self.hass.add_job(
self.async_see(mac, dev_id, host_name, location_name, gps,
gps_accuracy, battery, attributes)
)
@asyncio.coroutine
def async_see(self, mac: str=None, dev_id: str=None, host_name: str=None,
location_name: str=None, gps: GPSType=None,
gps_accuracy=None, battery: str=None, attributes: dict=None):
"""Notify the device tracker that you see a device.
This method is a coroutine.
"""
if mac is None and dev_id is None: if mac is None and dev_id is None:
raise HomeAssistantError('Neither mac or device id passed in') raise HomeAssistantError('Neither mac or device id passed in')
elif mac is not None: elif mac is not None:
@ -221,10 +247,10 @@ class DeviceTracker(object):
device = self.devices.get(dev_id) device = self.devices.get(dev_id)
if device: if device:
device.seen(host_name, location_name, gps, gps_accuracy, yield from device.async_seen(host_name, location_name, gps,
battery, attributes) gps_accuracy, battery, attributes)
if device.track: if device.track:
device.update_ha_state() yield from device.async_update_ha_state()
return return
# If no device can be found, create it # If no device can be found, create it
@ -236,46 +262,60 @@ class DeviceTracker(object):
if mac is not None: if mac is not None:
self.mac_to_dev[mac] = device self.mac_to_dev[mac] = device
device.seen(host_name, location_name, gps, gps_accuracy, battery, yield from device.async_seen(host_name, location_name, gps,
attributes) gps_accuracy, battery, attributes)
if device.track: if device.track:
device.update_ha_state() yield from device.async_update_ha_state()
self.hass.bus.fire(EVENT_NEW_DEVICE, { self.hass.bus.async_fire(EVENT_NEW_DEVICE, {
ATTR_ENTITY_ID: device.entity_id, ATTR_ENTITY_ID: device.entity_id,
ATTR_HOST_NAME: device.host_name, ATTR_HOST_NAME: device.host_name,
}) })
# During init, we ignore the group # During init, we ignore the group
if self.group is not None: if self.group is not None:
self.group.update_tracked_entity_ids( yield from self.group.async_update_tracked_entity_ids(
list(self.group.tracking) + [device.entity_id]) list(self.group.tracking) + [device.entity_id])
update_config(self.hass.config.path(YAML_DEVICES), dev_id, device)
def setup_group(self): # update known_devices.yaml
"""Initialize group for all tracked devices.""" self.hass.async_add_job(
run_coroutine_threadsafe( self.async_update_config(self.hass.config.path(YAML_DEVICES),
self.async_setup_group(), self.hass.loop).result() dev_id, device)
)
@asyncio.coroutine
def async_update_config(self, path, dev_id, device):
"""Add device to YAML configuration file.
This method is a coroutine.
"""
with (yield from self._is_updating):
self.hass.loop.run_in_executor(
None, update_config, self.hass.config.path(YAML_DEVICES),
dev_id, device)
@asyncio.coroutine @asyncio.coroutine
def async_setup_group(self): def async_setup_group(self):
"""Initialize group for all tracked devices. """Initialize group for all tracked devices.
This method must be run in the event loop. This method is a coroutine.
""" """
entity_ids = (dev.entity_id for dev in self.devices.values() entity_ids = (dev.entity_id for dev in self.devices.values()
if dev.track) if dev.track)
self.group = yield from group.Group.async_create_group( self.group = yield from group.Group.async_create_group(
self.hass, GROUP_NAME_ALL_DEVICES, entity_ids, False) self.hass, GROUP_NAME_ALL_DEVICES, entity_ids, False)
def update_stale(self, now: dt_util.dt.datetime): @callback
"""Update stale devices.""" def async_update_stale(self, now: dt_util.dt.datetime):
with self.lock: """Update stale devices.
This method must be run in the event loop.
"""
for device in self.devices.values(): for device in self.devices.values():
if (device.track and device.last_update_home and if (device.track and device.last_update_home) and \
device.stale(now)): device.stale(now):
device.update_ha_state(True) self.hass.async_add_job(device.async_update_ha_state(True))
class Device(Entity): class Device(Entity):
@ -362,7 +402,8 @@ class Device(Entity):
"""If device should be hidden.""" """If device should be hidden."""
return self.away_hide and self.state != STATE_HOME return self.away_hide and self.state != STATE_HOME
def seen(self, host_name: str=None, location_name: str=None, @asyncio.coroutine
def async_seen(self, host_name: str=None, location_name: str=None,
gps: GPSType=None, gps_accuracy=0, battery: str=None, gps: GPSType=None, gps_accuracy=0, battery: str=None,
attributes: dict=None): attributes: dict=None):
"""Mark the device as seen.""" """Mark the device as seen."""
@ -373,28 +414,38 @@ class Device(Entity):
self.battery = battery self.battery = battery
self.attributes = attributes self.attributes = attributes
self.gps = None self.gps = None
if gps is not None: if gps is not None:
try: try:
self.gps = float(gps[0]), float(gps[1]) self.gps = float(gps[0]), float(gps[1])
except (ValueError, TypeError, IndexError): except (ValueError, TypeError, IndexError):
_LOGGER.warning('Could not parse gps value for %s: %s', _LOGGER.warning('Could not parse gps value for %s: %s',
self.dev_id, gps) self.dev_id, gps)
self.update()
# pylint: disable=not-an-iterable
yield from self.async_update()
def stale(self, now: dt_util.dt.datetime=None): def stale(self, now: dt_util.dt.datetime=None):
"""Return if device state is stale.""" """Return if device state is stale.
Async friendly.
"""
return self.last_seen and \ return self.last_seen and \
(now or dt_util.utcnow()) - self.last_seen > self.consider_home (now or dt_util.utcnow()) - self.last_seen > self.consider_home
def update(self): @asyncio.coroutine
"""Update state of entity.""" def async_update(self):
"""Update state of entity.
This method is a coroutine.
"""
if not self.last_seen: if not self.last_seen:
return return
elif self.location_name: elif self.location_name:
self._state = self.location_name self._state = self.location_name
elif self.gps is not None: elif self.gps is not None:
zone_state = zone.active_zone(self.hass, self.gps[0], self.gps[1], zone_state = zone.async_active_zone(
self.gps_accuracy) self.hass, self.gps[0], self.gps[1], self.gps_accuracy)
if zone_state is None: if zone_state is None:
self._state = STATE_NOT_HOME self._state = STATE_NOT_HOME
elif zone_state.entity_id == zone.ENTITY_ID_HOME: elif zone_state.entity_id == zone.ENTITY_ID_HOME:
@ -412,6 +463,17 @@ class Device(Entity):
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta): def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
"""Load devices from YAML configuration file.""" """Load devices from YAML configuration file."""
return run_coroutine_threadsafe(
async_load_config(path, hass, consider_home), hass.loop).result()
@asyncio.coroutine
def async_load_config(path: str, hass: HomeAssistantType,
consider_home: timedelta):
"""Load devices from YAML configuration file.
This method is a coroutine.
"""
dev_schema = vol.Schema({ dev_schema = vol.Schema({
vol.Required('name'): cv.string, vol.Required('name'): cv.string,
vol.Optional('track', default=False): cv.boolean, vol.Optional('track', default=False): cv.boolean,
@ -426,7 +488,8 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
try: try:
result = [] result = []
try: try:
devices = load_yaml_config_file(path) devices = yield from hass.loop.run_in_executor(
None, load_yaml_config_file, path)
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err)) _LOGGER.error('Unable to load %s: %s', path, str(err))
return [] return []
@ -436,7 +499,7 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
device = dev_schema(device) device = dev_schema(device)
device['dev_id'] = cv.slugify(dev_id) device['dev_id'] = cv.slugify(dev_id)
except vol.Invalid as exp: except vol.Invalid as exp:
log_exception(exp, dev_id, devices, hass) async_log_exception(exp, dev_id, devices, hass)
else: else:
result.append(Device(hass, **device)) result.append(Device(hass, **device))
return result return result
@ -445,9 +508,13 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
return [] return []
def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType, @asyncio.coroutine
scanner: Any, see_device: Callable): def async_setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
"""Helper method to connect scanner-based platform to device tracker.""" scanner: Any, async_see_device: Callable):
"""Helper method to connect scanner-based platform to device tracker.
This method is a coroutine.
"""
interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
# Initial scan of each mac we also tell about host name for config # Initial scan of each mac we also tell about host name for config
@ -455,18 +522,20 @@ def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
def device_tracker_scan(now: dt_util.dt.datetime): def device_tracker_scan(now: dt_util.dt.datetime):
"""Called when interval matches.""" """Called when interval matches."""
for mac in scanner.scan_devices(): found_devices = scanner.scan_devices()
for mac in found_devices:
if mac in seen: if mac in seen:
host_name = None host_name = None
else: else:
host_name = scanner.get_device_name(mac) host_name = scanner.get_device_name(mac)
seen.add(mac) seen.add(mac)
see_device(mac=mac, host_name=host_name) hass.async_add_job(async_see_device(mac=mac, host_name=host_name))
track_utc_time_change(hass, device_tracker_scan, second=range(0, 60, async_track_utc_time_change(
interval)) hass, device_tracker_scan, second=range(0, 60, interval))
device_tracker_scan(None) hass.async_add_job(device_tracker_scan, None)
def update_config(path: str, dev_id: str, device: Device): def update_config(path: str, dev_id: str, device: Device):
@ -484,7 +553,10 @@ def update_config(path: str, dev_id: str, device: Device):
def get_gravatar_for_email(email: str): def get_gravatar_for_email(email: str):
"""Return an 80px Gravatar for the given email address.""" """Return an 80px Gravatar for the given email address.
Async friendly.
"""
import hashlib import hashlib
url = 'https://www.gravatar.com/avatar/{}.jpg?s=80&d=wavatar' url = 'https://www.gravatar.com/avatar/{}.jpg?s=80&d=wavatar'
return url.format(hashlib.md5(email.encode('utf-8').lower()).hexdigest()) return url.format(hashlib.md5(email.encode('utf-8').lower()).hexdigest())

View file

@ -14,6 +14,7 @@ from homeassistant.const import (
CONF_LONGITUDE, CONF_ICON) CONF_LONGITUDE, CONF_ICON)
from homeassistant.helpers import config_per_platform from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.util.async import run_callback_threadsafe
from homeassistant.util.location import distance from homeassistant.util.location import distance
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -51,9 +52,19 @@ PLATFORM_SCHEMA = vol.Schema({
def active_zone(hass, latitude, longitude, radius=0): def active_zone(hass, latitude, longitude, radius=0):
"""Find the active zone for given latitude, longitude.""" """Find the active zone for given latitude, longitude."""
return run_callback_threadsafe(
hass.loop, async_active_zone, hass, latitude, longitude, radius
).result()
def async_active_zone(hass, latitude, longitude, radius=0):
"""Find the active zone for given latitude, longitude.
This method must be run in the event loop.
"""
# Sort entity IDs so that we are deterministic if equal distance to 2 zones # Sort entity IDs so that we are deterministic if equal distance to 2 zones
zones = (hass.states.get(entity_id) for entity_id zones = (hass.states.get(entity_id) for entity_id
in sorted(hass.states.entity_ids(DOMAIN))) in sorted(hass.states.async_entity_ids(DOMAIN)))
min_dist = None min_dist = None
closest = None closest = None
@ -80,7 +91,10 @@ def active_zone(hass, latitude, longitude, radius=0):
def in_zone(zone, latitude, longitude, radius=0): def in_zone(zone, latitude, longitude, radius=0):
"""Test if given latitude, longitude is in given zone.""" """Test if given latitude, longitude is in given zone.
Async friendly.
"""
zone_dist = distance( zone_dist = distance(
latitude, longitude, latitude, longitude,
zone.attributes[ATTR_LATITUDE], zone.attributes[ATTR_LONGITUDE]) zone.attributes[ATTR_LATITUDE], zone.attributes[ATTR_LONGITUDE])

View file

@ -14,6 +14,16 @@ ATTR_PLATFORM = 'platform'
def listen(hass, service, callback): def listen(hass, service, callback):
"""Setup listener for discovery of specific service. """Setup listener for discovery of specific service.
Service can be a string or a list/tuple.
"""
run_callback_threadsafe(
hass.loop, async_listen, hass, service, callback).result()
@core.callback
def async_listen(hass, service, callback):
"""Setup listener for discovery of specific service.
Service can be a string or a list/tuple. Service can be a string or a list/tuple.
""" """
if isinstance(service, str): if isinstance(service, str):
@ -21,12 +31,14 @@ def listen(hass, service, callback):
else: else:
service = tuple(service) service = tuple(service)
@core.callback
def discovery_event_listener(event): def discovery_event_listener(event):
"""Listen for discovery events.""" """Listen for discovery events."""
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service: if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service:
callback(event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED)) hass.async_add_job(callback, event.data[ATTR_SERVICE],
event.data.get(ATTR_DISCOVERED))
hass.bus.listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener) hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
def discover(hass, service, discovered=None, component=None, hass_config=None): def discover(hass, service, discovered=None, component=None, hass_config=None):

View file

@ -8,7 +8,10 @@ from homeassistant.util import location as loc_util
def has_location(state: State) -> bool: def has_location(state: State) -> bool:
"""Test if state contains a valid location.""" """Test if state contains a valid location.
Async friendly.
"""
return (isinstance(state, State) and return (isinstance(state, State) and
isinstance(state.attributes.get(ATTR_LATITUDE), float) and isinstance(state.attributes.get(ATTR_LATITUDE), float) and
isinstance(state.attributes.get(ATTR_LONGITUDE), float)) isinstance(state.attributes.get(ATTR_LONGITUDE), float))
@ -16,7 +19,10 @@ def has_location(state: State) -> bool:
def closest(latitude: float, longitude: float, def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State: states: Sequence[State]) -> State:
"""Return closest state to point.""" """Return closest state to point.
Async friendly.
"""
with_location = [state for state in states if has_location(state)] with_location = [state for state in states if has_location(state)]
if not with_location: if not with_location:

View file

@ -51,7 +51,10 @@ def detect_location_info():
def distance(lat1, lon1, lat2, lon2): def distance(lat1, lon1, lat2, lon2):
"""Calculate the distance in meters between two points.""" """Calculate the distance in meters between two points.
Async friendly.
"""
return vincenty((lat1, lon1), (lat2, lon2)) * 1000 return vincenty((lat1, lon1), (lat2, lon2)) * 1000
@ -88,6 +91,8 @@ def vincenty(point1: Tuple[float, float], point2: Tuple[float, float],
Result in kilometers or miles between two points on the surface of a Result in kilometers or miles between two points on the surface of a
spheroid. spheroid.
Async friendly.
""" """
# short-circuit coincident points # short-circuit coincident points
if point1[0] == point2[0] and point1[1] == point2[1]: if point1[0] == point2[0] and point1[1] == point2[1]:

View file

@ -10,6 +10,7 @@ import os
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
from homeassistant.loader import get_component from homeassistant.loader import get_component
from homeassistant.util.async import run_coroutine_threadsafe
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN,
@ -280,7 +281,7 @@ class TestComponentsDeviceTracker(unittest.TestCase):
self.assertSequenceEqual((entity_id,), self.assertSequenceEqual((entity_id,),
state.attributes.get(ATTR_ENTITY_ID)) state.attributes.get(ATTR_ENTITY_ID))
@patch('homeassistant.components.device_tracker.DeviceTracker.see') @patch('homeassistant.components.device_tracker.DeviceTracker.async_see')
def test_see_service(self, mock_see): def test_see_service(self, mock_see):
"""Test the see service with a unicode dev_id and NO MAC.""" """Test the see service with a unicode dev_id and NO MAC."""
self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN, self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN,
@ -375,20 +376,22 @@ class TestComponentsDeviceTracker(unittest.TestCase):
# No device id or MAC(not added) # No device id or MAC(not added)
with self.assertRaises(HomeAssistantError): with self.assertRaises(HomeAssistantError):
tracker.see() run_coroutine_threadsafe(
tracker.async_see(), self.hass.loop).result()
assert mock_warning.call_count == 0 assert mock_warning.call_count == 0
# Ignore gps on invalid GPS (both added & warnings) # Ignore gps on invalid GPS (both added & warnings)
tracker.see(mac='mac_1_bad_gps', gps=1) tracker.see(mac='mac_1_bad_gps', gps=1)
tracker.see(mac='mac_2_bad_gps', gps=[1]) tracker.see(mac='mac_2_bad_gps', gps=[1])
tracker.see(mac='mac_3_bad_gps', gps='gps') tracker.see(mac='mac_3_bad_gps', gps='gps')
self.hass.block_till_done()
config = device_tracker.load_config(self.yaml_devices, self.hass, config = device_tracker.load_config(self.yaml_devices, self.hass,
timedelta(seconds=0)) timedelta(seconds=0))
assert mock_warning.call_count == 3 assert mock_warning.call_count == 3
assert len(config) == 4 assert len(config) == 4
@patch('homeassistant.components.device_tracker.log_exception') @patch('homeassistant.components.device_tracker.async_log_exception')
def test_config_failure(self, mock_ex): def test_config_failure(self, mock_ex):
"""Test that the device tracker see failures.""" """Test that the device tracker see failures."""
with assert_setup_component(0, device_tracker.DOMAIN): with assert_setup_component(0, device_tracker.DOMAIN):

View file

@ -37,7 +37,8 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
self.assertTrue('qos' in config) self.assertTrue('qos' in config)
with patch('homeassistant.components.device_tracker.mqtt.' with patch('homeassistant.components.device_tracker.mqtt.'
'setup_scanner', side_effect=mock_setup_scanner) as mock_sp: 'setup_scanner', autospec=True,
side_effect=mock_setup_scanner) as mock_sp:
dev_id = 'paulus' dev_id = 'paulus'
topic = '/location/paulus' topic = '/location/paulus'