Add Traffic Rule switches to UniFi Network (#118821)

* Add Traffic Rule switches to UniFi Network

* Retrieve Fix unifi traffic rule switches

Poll for traffic rule updates; have immediate feedback in the UI for modifying traffic rules

* Remove default values for unifi entity; Remove unnecessary code

* Begin updating traffic rule unit tests

* For the mock get request, allow for meta and data properties to not be appended to support v2 api requests

Fix traffic rule unit tests;

* inspect path to determine json response instead of passing an argument

* Remove entity id parameter from tests; remove unused code; rename traffic rule unique ID prefix

* Remove parameter with default.

* More code removal;

* Rename copy/paste variable; remove commented code; remove duplicate default code

---------

Co-authored-by: ViViDboarder <ViViDboarder@gmail.com>
This commit is contained in:
bdowden 2024-07-30 11:26:08 -04:00 committed by GitHub
parent be24475cee
commit 18a7d15d14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 164 additions and 8 deletions

View file

@ -7,9 +7,10 @@ Make sure expected clients are available for platforms.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Sequence
from datetime import timedelta from datetime import timedelta
from functools import partial from functools import partial
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from aiounifi.interfaces.api_handlers import ItemEvent from aiounifi.interfaces.api_handlers import ItemEvent
@ -18,6 +19,7 @@ from homeassistant.core import callback
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS
from ..entity import UnifiEntity, UnifiEntityDescription from ..entity import UnifiEntity, UnifiEntityDescription
@ -26,6 +28,7 @@ if TYPE_CHECKING:
from .hub import UnifiHub from .hub import UnifiHub
CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1) CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1)
POLL_INTERVAL = timedelta(seconds=10)
class UnifiEntityLoader: class UnifiEntityLoader:
@ -43,10 +46,24 @@ class UnifiEntityLoader:
hub.api.port_forwarding.update, hub.api.port_forwarding.update,
hub.api.sites.update, hub.api.sites.update,
hub.api.system_information.update, hub.api.system_information.update,
hub.api.traffic_rules.update,
hub.api.wlans.update, hub.api.wlans.update,
) )
self.polling_api_updaters = (hub.api.traffic_rules.update,)
self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS] self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS]
self._dataUpdateCoordinator = DataUpdateCoordinator(
hub.hass,
LOGGER,
name="Unifi entity poller",
update_method=self._update_pollable_api_data,
update_interval=POLL_INTERVAL,
)
self._update_listener = self._dataUpdateCoordinator.async_add_listener(
update_callback=lambda: None
)
self.platforms: list[ self.platforms: list[
tuple[ tuple[
AddEntitiesCallback, AddEntitiesCallback,
@ -65,16 +82,25 @@ class UnifiEntityLoader:
self._restore_inactive_clients() self._restore_inactive_clients()
self.wireless_clients.update_clients(set(self.hub.api.clients.values())) self.wireless_clients.update_clients(set(self.hub.api.clients.values()))
async def _refresh_api_data(self) -> None: async def _refresh_data(
"""Refresh API data from network application.""" self, updaters: Sequence[Callable[[], Coroutine[Any, Any, None]]]
) -> None:
results = await asyncio.gather( results = await asyncio.gather(
*[update() for update in self.api_updaters], *[update() for update in updaters],
return_exceptions=True, return_exceptions=True,
) )
for result in results: for result in results:
if result is not None: if result is not None:
LOGGER.warning("Exception on update %s", result) LOGGER.warning("Exception on update %s", result)
async def _update_pollable_api_data(self) -> None:
"""Refresh API data for pollable updaters."""
await self._refresh_data(self.polling_api_updaters)
async def _refresh_api_data(self) -> None:
"""Refresh API data from network application."""
await self._refresh_data(self.api_updaters)
@callback @callback
def _restore_inactive_clients(self) -> None: def _restore_inactive_clients(self) -> None:
"""Restore inactive clients. """Restore inactive clients.

View file

@ -20,6 +20,7 @@ from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups
from aiounifi.interfaces.outlets import Outlets from aiounifi.interfaces.outlets import Outlets
from aiounifi.interfaces.port_forwarding import PortForwarding from aiounifi.interfaces.port_forwarding import PortForwarding
from aiounifi.interfaces.ports import Ports from aiounifi.interfaces.ports import Ports
from aiounifi.interfaces.traffic_rules import TrafficRules
from aiounifi.interfaces.wlans import Wlans from aiounifi.interfaces.wlans import Wlans
from aiounifi.models.api import ApiItemT from aiounifi.models.api import ApiItemT
from aiounifi.models.client import Client, ClientBlockRequest from aiounifi.models.client import Client, ClientBlockRequest
@ -30,6 +31,7 @@ from aiounifi.models.event import Event, EventKey
from aiounifi.models.outlet import Outlet from aiounifi.models.outlet import Outlet
from aiounifi.models.port import Port from aiounifi.models.port import Port
from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest
from aiounifi.models.traffic_rule import TrafficRule, TrafficRuleEnableRequest
from aiounifi.models.wlan import Wlan, WlanEnableRequest from aiounifi.models.wlan import Wlan, WlanEnableRequest
from homeassistant.components.switch import ( from homeassistant.components.switch import (
@ -94,8 +96,8 @@ def async_dpi_group_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
@callback @callback
def async_port_forward_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo: def async_unifi_network_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
"""Create device registry entry for port forward.""" """Create device registry entry for the UniFi Network application."""
unique_id = hub.config.entry.unique_id unique_id = hub.config.entry.unique_id
assert unique_id is not None assert unique_id is not None
return DeviceInfo( return DeviceInfo(
@ -158,6 +160,16 @@ async def async_port_forward_control_fn(
await hub.api.request(PortForwardEnableRequest.create(port_forward, target)) await hub.api.request(PortForwardEnableRequest.create(port_forward, target))
async def async_traffic_rule_control_fn(
hub: UnifiHub, obj_id: str, target: bool
) -> None:
"""Control traffic rule state."""
traffic_rule = hub.api.traffic_rules[obj_id].raw
await hub.api.request(TrafficRuleEnableRequest.create(traffic_rule, target))
# Update the traffic rules so the UI is updated appropriately
await hub.api.traffic_rules.update()
async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None: async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None:
"""Control outlet relay.""" """Control outlet relay."""
await hub.api.request(WlanEnableRequest.create(obj_id, target)) await hub.api.request(WlanEnableRequest.create(obj_id, target))
@ -232,12 +244,25 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSwitchEntityDescription, ...] = (
icon="mdi:upload-network", icon="mdi:upload-network",
api_handler_fn=lambda api: api.port_forwarding, api_handler_fn=lambda api: api.port_forwarding,
control_fn=async_port_forward_control_fn, control_fn=async_port_forward_control_fn,
device_info_fn=async_port_forward_device_info_fn, device_info_fn=async_unifi_network_device_info_fn,
is_on_fn=lambda hub, port_forward: port_forward.enabled, is_on_fn=lambda hub, port_forward: port_forward.enabled,
name_fn=lambda port_forward: f"{port_forward.name}", name_fn=lambda port_forward: f"{port_forward.name}",
object_fn=lambda api, obj_id: api.port_forwarding[obj_id], object_fn=lambda api, obj_id: api.port_forwarding[obj_id],
unique_id_fn=lambda hub, obj_id: f"port_forward-{obj_id}", unique_id_fn=lambda hub, obj_id: f"port_forward-{obj_id}",
), ),
UnifiSwitchEntityDescription[TrafficRules, TrafficRule](
key="Traffic rule control",
device_class=SwitchDeviceClass.SWITCH,
entity_category=EntityCategory.CONFIG,
icon="mdi:security-network",
api_handler_fn=lambda api: api.traffic_rules,
control_fn=async_traffic_rule_control_fn,
device_info_fn=async_unifi_network_device_info_fn,
is_on_fn=lambda hub, traffic_rule: traffic_rule.enabled,
name_fn=lambda traffic_rule: traffic_rule.description,
object_fn=lambda api, obj_id: api.traffic_rules[obj_id],
unique_id_fn=lambda hub, obj_id: f"traffic_rule-{obj_id}",
),
UnifiSwitchEntityDescription[Ports, Port]( UnifiSwitchEntityDescription[Ports, Port](
key="PoE port control", key="PoE port control",
device_class=SwitchDeviceClass.OUTLET, device_class=SwitchDeviceClass.OUTLET,

View file

@ -160,6 +160,7 @@ def fixture_request(
dpi_app_payload: list[dict[str, Any]], dpi_app_payload: list[dict[str, Any]],
dpi_group_payload: list[dict[str, Any]], dpi_group_payload: list[dict[str, Any]],
port_forward_payload: list[dict[str, Any]], port_forward_payload: list[dict[str, Any]],
traffic_rule_payload: list[dict[str, Any]],
site_payload: list[dict[str, Any]], site_payload: list[dict[str, Any]],
system_information_payload: list[dict[str, Any]], system_information_payload: list[dict[str, Any]],
wlan_payload: list[dict[str, Any]], wlan_payload: list[dict[str, Any]],
@ -170,9 +171,16 @@ def fixture_request(
url = f"https://{host}:{DEFAULT_PORT}" url = f"https://{host}:{DEFAULT_PORT}"
def mock_get_request(path: str, payload: list[dict[str, Any]]) -> None: def mock_get_request(path: str, payload: list[dict[str, Any]]) -> None:
# APIV2 request respoonses have `meta` and `data` automatically appended
json = {}
if path.startswith("/v2"):
json = payload
else:
json = {"meta": {"rc": "OK"}, "data": payload}
aioclient_mock.get( aioclient_mock.get(
f"{url}{path}", f"{url}{path}",
json={"meta": {"rc": "OK"}, "data": payload}, json=json,
headers={"content-type": CONTENT_TYPE_JSON}, headers={"content-type": CONTENT_TYPE_JSON},
) )
@ -182,6 +190,7 @@ def fixture_request(
json={"data": "login successful", "meta": {"rc": "ok"}}, json={"data": "login successful", "meta": {"rc": "ok"}},
headers={"content-type": CONTENT_TYPE_JSON}, headers={"content-type": CONTENT_TYPE_JSON},
) )
mock_get_request("/api/self/sites", site_payload) mock_get_request("/api/self/sites", site_payload)
mock_get_request(f"/api/s/{site_id}/stat/sta", client_payload) mock_get_request(f"/api/s/{site_id}/stat/sta", client_payload)
mock_get_request(f"/api/s/{site_id}/rest/user", clients_all_payload) mock_get_request(f"/api/s/{site_id}/rest/user", clients_all_payload)
@ -191,6 +200,7 @@ def fixture_request(
mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload) mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload)
mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload) mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload)
mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload) mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload)
mock_get_request(f"/v2/api/site/{site_id}/trafficrules", traffic_rule_payload)
return __mock_requests return __mock_requests
@ -262,6 +272,12 @@ def fixture_system_information_data() -> list[dict[str, Any]]:
] ]
@pytest.fixture(name="traffic_rule_payload")
def traffic_rule_payload_data() -> list[dict[str, Any]]:
"""Traffic rule data."""
return []
@pytest.fixture(name="wlan_payload") @pytest.fixture(name="wlan_payload")
def fixture_wlan_data() -> list[dict[str, Any]]: def fixture_wlan_data() -> list[dict[str, Any]]:
"""WLAN data.""" """WLAN data."""

View file

@ -774,6 +774,37 @@ PORT_FORWARD_PLEX = {
"src": "any", "src": "any",
} }
TRAFFIC_RULE = {
"_id": "6452cd9b859d5b11aa002ea1",
"action": "BLOCK",
"app_category_ids": [],
"app_ids": [],
"bandwidth_limit": {
"download_limit_kbps": 1024,
"enabled": False,
"upload_limit_kbps": 1024,
},
"description": "Test Traffic Rule",
"name": "Test Traffic Rule",
"domains": [],
"enabled": True,
"ip_addresses": [],
"ip_ranges": [],
"matching_target": "INTERNET",
"network_ids": [],
"regions": [],
"schedule": {
"date_end": "2023-05-10",
"date_start": "2023-05-03",
"mode": "ALWAYS",
"repeat_on_days": [],
"time_all_day": False,
"time_range_end": "12:00",
"time_range_start": "09:00",
},
"target_devices": [{"client_mac": CLIENT_1["mac"], "type": "CLIENT"}],
}
@pytest.mark.parametrize("client_payload", [[CONTROLLER_HOST]]) @pytest.mark.parametrize("client_payload", [[CONTROLLER_HOST]])
@pytest.mark.parametrize("device_payload", [[DEVICE_1]]) @pytest.mark.parametrize("device_payload", [[DEVICE_1]])
@ -1072,6 +1103,64 @@ async def test_dpi_switches_add_second_app(
assert hass.states.get("switch.block_media_streaming").state == STATE_ON assert hass.states.get("switch.block_media_streaming").state == STATE_ON
@pytest.mark.parametrize(("traffic_rule_payload"), [([TRAFFIC_RULE])])
@pytest.mark.usefixtures("config_entry_setup")
async def test_traffic_rules(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
mock_websocket_message,
config_entry_setup: ConfigEntry,
traffic_rule_payload: list[dict[str, Any]],
) -> None:
"""Test control of UniFi traffic rules."""
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1
# Validate state object
switch_1 = hass.states.get("switch.unifi_network_test_traffic_rule")
assert switch_1.state == STATE_ON
assert switch_1.attributes.get(ATTR_DEVICE_CLASS) == SwitchDeviceClass.SWITCH
traffic_rule = deepcopy(traffic_rule_payload[0])
# Disable traffic rule
aioclient_mock.put(
f"https://{config_entry_setup.data[CONF_HOST]}:1234"
f"/v2/api/site/{config_entry_setup.data[CONF_SITE_ID]}/trafficrules/{traffic_rule['_id']}",
)
call_count = aioclient_mock.call_count
await hass.services.async_call(
SWITCH_DOMAIN,
"turn_off",
{"entity_id": "switch.unifi_network_test_traffic_rule"},
blocking=True,
)
# Updating the value for traffic rules will make another call to retrieve the values
assert aioclient_mock.call_count == call_count + 2
expected_disable_call = deepcopy(traffic_rule)
expected_disable_call["enabled"] = False
assert aioclient_mock.mock_calls[call_count][2] == expected_disable_call
call_count = aioclient_mock.call_count
# Enable traffic rule
await hass.services.async_call(
SWITCH_DOMAIN,
"turn_on",
{"entity_id": "switch.unifi_network_test_traffic_rule"},
blocking=True,
)
expected_enable_call = deepcopy(traffic_rule)
expected_enable_call["enabled"] = True
assert aioclient_mock.call_count == call_count + 2
assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call
@pytest.mark.parametrize( @pytest.mark.parametrize(
("device_payload", "entity_id", "outlet_index", "expected_switches"), ("device_payload", "entity_id", "outlet_index", "expected_switches"),
[ [