Add support for round-robin DNS (#115218)

* Add support for RR DNS

* 🧪 Update tests for DNS IP round-robin

* 🤖 Configure DNS IP round-robin automatically

* 🐛 Sort IPv6 addresses correctly

* Limit returned IPs and cleanup test class

* 🔟 Change max DNS results to 10

* Rename IPs to ip_addresses
This commit is contained in:
pemontto 2024-05-07 10:49:13 +01:00 committed by GitHub
parent 3d700e2b71
commit 1c414966fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 49 additions and 10 deletions

View file

@ -176,7 +176,10 @@ class DnsIPOptionsFlowHandler(OptionsFlowWithConfigEntry):
else: else:
return self.async_create_entry( return self.async_create_entry(
title=self.config_entry.title, title=self.config_entry.title,
data={CONF_RESOLVER: resolver, CONF_RESOLVER_IPV6: resolver_ipv6}, data={
CONF_RESOLVER: resolver,
CONF_RESOLVER_IPV6: resolver_ipv6,
},
) )
schema = self.add_suggested_values_to_schema( schema = self.add_suggested_values_to_schema(

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from ipaddress import IPv4Address, IPv6Address
import logging import logging
import aiodns import aiodns
@ -25,12 +26,23 @@ from .const import (
) )
DEFAULT_RETRIES = 2 DEFAULT_RETRIES = 2
MAX_RESULTS = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=120) SCAN_INTERVAL = timedelta(seconds=120)
def sort_ips(ips: list, querytype: str) -> list:
"""Join IPs into a single string."""
if querytype == "AAAA":
ips = [IPv6Address(ip) for ip in ips]
else:
ips = [IPv4Address(ip) for ip in ips]
return [str(ip) for ip in sorted(ips)][:MAX_RESULTS]
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None: ) -> None:
@ -41,6 +53,7 @@ async def async_setup_entry(
resolver_ipv4 = entry.options[CONF_RESOLVER] resolver_ipv4 = entry.options[CONF_RESOLVER]
resolver_ipv6 = entry.options[CONF_RESOLVER_IPV6] resolver_ipv6 = entry.options[CONF_RESOLVER_IPV6]
entities = [] entities = []
if entry.data[CONF_IPV4]: if entry.data[CONF_IPV4]:
entities.append(WanIpSensor(name, hostname, resolver_ipv4, False)) entities.append(WanIpSensor(name, hostname, resolver_ipv4, False))
@ -92,7 +105,11 @@ class WanIpSensor(SensorEntity):
response = None response = None
if response: if response:
self._attr_native_value = response[0].host sorted_ips = sort_ips(
[res.host for res in response], querytype=self.querytype
)
self._attr_native_value = sorted_ips[0]
self._attr_extra_state_attributes["ip_addresses"] = sorted_ips
self._attr_available = True self._attr_available = True
self._retries = DEFAULT_RETRIES self._retries = DEFAULT_RETRIES
elif self._retries > 0: elif self._retries > 0:

View file

@ -6,8 +6,10 @@ from __future__ import annotations
class QueryResult: class QueryResult:
"""Return Query results.""" """Return Query results."""
host = "1.2.3.4" def __init__(self, ip="1.2.3.4", ttl=60) -> None:
ttl = 60 """Initialize QueryResult class."""
self.host = ip
self.ttl = ttl
class RetrieveDNS: class RetrieveDNS:
@ -22,11 +24,20 @@ class RetrieveDNS:
self._nameservers = ["1.2.3.4"] self._nameservers = ["1.2.3.4"]
self.error = error self.error = error
async def query(self, hostname, qtype) -> dict[str, str]: async def query(self, hostname, qtype) -> list[QueryResult]:
"""Return information.""" """Return information."""
if self.error: if self.error:
raise self.error raise self.error
return [QueryResult] if qtype == "AAAA":
results = [
QueryResult("2001:db8:77::face:b00c"),
QueryResult("2001:db8:77::dead:beef"),
QueryResult("2001:db8::77:dead:beef"),
QueryResult("2001:db8:66::dead:beef"),
]
else:
results = [QueryResult("1.2.3.4"), QueryResult("1.1.1.1")]
return results
@property @property
def nameservers(self) -> list[str]: def nameservers(self) -> list[str]:

View file

@ -56,8 +56,15 @@ async def test_sensor(hass: HomeAssistant) -> None:
state1 = hass.states.get("sensor.home_assistant_io") state1 = hass.states.get("sensor.home_assistant_io")
state2 = hass.states.get("sensor.home_assistant_io_ipv6") state2 = hass.states.get("sensor.home_assistant_io_ipv6")
assert state1.state == "1.2.3.4" assert state1.state == "1.1.1.1"
assert state2.state == "1.2.3.4" assert state1.attributes["ip_addresses"] == ["1.1.1.1", "1.2.3.4"]
assert state2.state == "2001:db8::77:dead:beef"
assert state2.attributes["ip_addresses"] == [
"2001:db8::77:dead:beef",
"2001:db8:66::dead:beef",
"2001:db8:77::dead:beef",
"2001:db8:77::face:b00c",
]
async def test_sensor_no_response( async def test_sensor_no_response(
@ -92,7 +99,7 @@ async def test_sensor_no_response(
state = hass.states.get("sensor.home_assistant_io") state = hass.states.get("sensor.home_assistant_io")
assert state.state == "1.2.3.4" assert state.state == "1.1.1.1"
dns_mock.error = DNSError() dns_mock.error = DNSError()
with patch( with patch(
@ -107,7 +114,8 @@ async def test_sensor_no_response(
# Allows 2 retries before going unavailable # Allows 2 retries before going unavailable
state = hass.states.get("sensor.home_assistant_io") state = hass.states.get("sensor.home_assistant_io")
assert state.state == "1.2.3.4" assert state.state == "1.1.1.1"
assert state.attributes["ip_addresses"] == ["1.1.1.1", "1.2.3.4"]
freezer.tick(timedelta(seconds=SCAN_INTERVAL.seconds)) freezer.tick(timedelta(seconds=SCAN_INTERVAL.seconds))
async_fire_time_changed(hass) async_fire_time_changed(hass)