Resolve zones and return state in find_coordinates (#66081)

This commit is contained in:
Kevin Stillhammer 2022-02-09 10:43:20 +01:00 committed by GitHub
parent bc9ccf0e47
commit a0119f7ed0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 31 deletions

View file

@ -4,8 +4,6 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
import logging import logging
import voluptuous as vol
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant, State
from homeassistant.util import location as loc_util from homeassistant.util import location as loc_util
@ -48,29 +46,42 @@ def closest(latitude: float, longitude: float, states: Iterable[State]) -> State
def find_coordinates( def find_coordinates(
hass: HomeAssistant, entity_id: str, recursion_history: list | None = None hass: HomeAssistant, name: str, recursion_history: list | None = None
) -> str | None: ) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'.""" """Try to resolve the a location from a supplied name or entity_id.
if (entity_state := hass.states.get(entity_id)) is None:
_LOGGER.error("Unable to find entity %s", entity_id)
return None
# Check if the entity has location attributes Will recursively resolve an entity if pointed to by the state of the supplied entity.
Returns coordinates in the form of '90.000,180.000', an address or the state of the last resolved entity.
"""
# Check if a friendly name of a zone was supplied
if (zone_coords := resolve_zone(hass, name)) is not None:
return zone_coords
# Check if an entity_id was supplied.
if (entity_state := hass.states.get(name)) is None:
_LOGGER.debug("Unable to find entity %s", name)
return name
# Check if the entity_state has location attributes
if has_location(entity_state): if has_location(entity_state):
return _get_location_from_attributes(entity_state) return _get_location_from_attributes(entity_state)
# Check if device is in a zone # Check if entity_state is a zone
zone_entity = hass.states.get(f"zone.{entity_state.state}") zone_entity = hass.states.get(f"zone.{entity_state.state}")
if has_location(zone_entity): # type: ignore if has_location(zone_entity): # type: ignore
_LOGGER.debug( _LOGGER.debug(
"%s is in %s, getting zone location", entity_id, zone_entity.entity_id # type: ignore "%s is in %s, getting zone location", name, zone_entity.entity_id # type: ignore
) )
return _get_location_from_attributes(zone_entity) # type: ignore return _get_location_from_attributes(zone_entity) # type: ignore
# Resolve nested entity # Check if entity_state is a friendly name of a zone
if (zone_coords := resolve_zone(hass, entity_state.state)) is not None:
return zone_coords
# Check if entity_state is an entity_id
if recursion_history is None: if recursion_history is None:
recursion_history = [] recursion_history = []
recursion_history.append(entity_id) recursion_history.append(name)
if entity_state.state in recursion_history: if entity_state.state in recursion_history:
_LOGGER.error( _LOGGER.error(
"Circular reference detected while trying to find coordinates of an entity. The state of %s has already been checked", "Circular reference detected while trying to find coordinates of an entity. The state of %s has already been checked",
@ -83,21 +94,18 @@ def find_coordinates(
_LOGGER.debug("Resolving nested entity_id: %s", entity_state.state) _LOGGER.debug("Resolving nested entity_id: %s", entity_state.state)
return find_coordinates(hass, entity_state.state, recursion_history) return find_coordinates(hass, entity_state.state, recursion_history)
# Check if state is valid coordinate set # Might be an address, coordinates or anything else. This has to be checked by the caller.
try: return entity_state.state
# Import here, not at top-level to avoid circular import
from . import config_validation as cv # pylint: disable=import-outside-toplevel
cv.gps(entity_state.state.split(","))
except vol.Invalid: def resolve_zone(hass: HomeAssistant, zone_name: str) -> str | None:
_LOGGER.error( """Get a lat/long from a zones friendly_name or None if no zone is found by that friendly_name."""
"Entity %s does not contain a location and does not point at an entity that does: %s", states = hass.states.async_all("zone")
entity_id, for state in states:
entity_state.state, if state.name == zone_name:
) return _get_location_from_attributes(state)
return None
else: return None
return entity_state.state
def _get_location_from_attributes(entity_state: State) -> str: def _get_location_from_attributes(entity_state: State) -> str:

View file

@ -1,5 +1,5 @@
"""Tests Home Assistant location helpers.""" """Tests Home Assistant location helpers."""
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State from homeassistant.core import State
from homeassistant.helpers import location from homeassistant.helpers import location
@ -73,6 +73,21 @@ async def test_coordinates_function_device_tracker_in_zone(hass):
) )
async def test_coordinates_function_zone_friendly_name(hass):
"""Test coordinates function."""
hass.states.async_set(
"zone.home",
"zoning",
{"latitude": 32.87336, "longitude": -117.22943, ATTR_FRIENDLY_NAME: "my_home"},
)
hass.states.async_set(
"test.object",
"my_home",
)
assert location.find_coordinates(hass, "test.object") == "32.87336,-117.22943"
assert location.find_coordinates(hass, "my_home") == "32.87336,-117.22943"
async def test_coordinates_function_device_tracker_from_input_select(hass): async def test_coordinates_function_device_tracker_from_input_select(hass):
"""Test coordinates function.""" """Test coordinates function."""
hass.states.async_set( hass.states.async_set(
@ -96,15 +111,16 @@ def test_coordinates_function_returns_none_on_recursion(hass):
assert location.find_coordinates(hass, "test.first") is None assert location.find_coordinates(hass, "test.first") is None
async def test_coordinates_function_returns_none_if_invalid_coord(hass): async def test_coordinates_function_returns_state_if_no_coords(hass):
"""Test test_coordinates function.""" """Test test_coordinates function."""
hass.states.async_set( hass.states.async_set(
"test.object", "test.object",
"abc", "abc",
) )
assert location.find_coordinates(hass, "test.object") is None assert location.find_coordinates(hass, "test.object") == "abc"
def test_coordinates_function_returns_none_if_invalid_input(hass): def test_coordinates_function_returns_input_if_no_coords(hass):
"""Test test_coordinates function.""" """Test test_coordinates function."""
assert location.find_coordinates(hass, "test.abc") is None assert location.find_coordinates(hass, "test.abc") == "test.abc"
assert location.find_coordinates(hass, "abc") == "abc"