Add Traffic Rule switches to UniFi Network (#118821)
* Add Traffic Rule switches to UniFi Network * Retrieve Fix unifi traffic rule switches Poll for traffic rule updates; have immediate feedback in the UI for modifying traffic rules * Remove default values for unifi entity; Remove unnecessary code * Begin updating traffic rule unit tests * For the mock get request, allow for meta and data properties to not be appended to support v2 api requests Fix traffic rule unit tests; * inspect path to determine json response instead of passing an argument * Remove entity id parameter from tests; remove unused code; rename traffic rule unique ID prefix * Remove parameter with default. * More code removal; * Rename copy/paste variable; remove commented code; remove duplicate default code --------- Co-authored-by: ViViDboarder <ViViDboarder@gmail.com>
This commit is contained in:
parent
be24475cee
commit
18a7d15d14
4 changed files with 164 additions and 8 deletions
|
@ -7,9 +7,10 @@ Make sure expected clients are available for platforms.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable, Coroutine, Sequence
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from aiounifi.interfaces.api_handlers import ItemEvent
|
from aiounifi.interfaces.api_handlers import ItemEvent
|
||||||
|
|
||||||
|
@ -18,6 +19,7 @@ from homeassistant.core import callback
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||||
|
|
||||||
from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS
|
from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS
|
||||||
from ..entity import UnifiEntity, UnifiEntityDescription
|
from ..entity import UnifiEntity, UnifiEntityDescription
|
||||||
|
@ -26,6 +28,7 @@ if TYPE_CHECKING:
|
||||||
from .hub import UnifiHub
|
from .hub import UnifiHub
|
||||||
|
|
||||||
CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1)
|
CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1)
|
||||||
|
POLL_INTERVAL = timedelta(seconds=10)
|
||||||
|
|
||||||
|
|
||||||
class UnifiEntityLoader:
|
class UnifiEntityLoader:
|
||||||
|
@ -43,10 +46,24 @@ class UnifiEntityLoader:
|
||||||
hub.api.port_forwarding.update,
|
hub.api.port_forwarding.update,
|
||||||
hub.api.sites.update,
|
hub.api.sites.update,
|
||||||
hub.api.system_information.update,
|
hub.api.system_information.update,
|
||||||
|
hub.api.traffic_rules.update,
|
||||||
hub.api.wlans.update,
|
hub.api.wlans.update,
|
||||||
)
|
)
|
||||||
|
self.polling_api_updaters = (hub.api.traffic_rules.update,)
|
||||||
self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS]
|
self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS]
|
||||||
|
|
||||||
|
self._dataUpdateCoordinator = DataUpdateCoordinator(
|
||||||
|
hub.hass,
|
||||||
|
LOGGER,
|
||||||
|
name="Unifi entity poller",
|
||||||
|
update_method=self._update_pollable_api_data,
|
||||||
|
update_interval=POLL_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._update_listener = self._dataUpdateCoordinator.async_add_listener(
|
||||||
|
update_callback=lambda: None
|
||||||
|
)
|
||||||
|
|
||||||
self.platforms: list[
|
self.platforms: list[
|
||||||
tuple[
|
tuple[
|
||||||
AddEntitiesCallback,
|
AddEntitiesCallback,
|
||||||
|
@ -65,16 +82,25 @@ class UnifiEntityLoader:
|
||||||
self._restore_inactive_clients()
|
self._restore_inactive_clients()
|
||||||
self.wireless_clients.update_clients(set(self.hub.api.clients.values()))
|
self.wireless_clients.update_clients(set(self.hub.api.clients.values()))
|
||||||
|
|
||||||
async def _refresh_api_data(self) -> None:
|
async def _refresh_data(
|
||||||
"""Refresh API data from network application."""
|
self, updaters: Sequence[Callable[[], Coroutine[Any, Any, None]]]
|
||||||
|
) -> None:
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[update() for update in self.api_updaters],
|
*[update() for update in updaters],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
for result in results:
|
for result in results:
|
||||||
if result is not None:
|
if result is not None:
|
||||||
LOGGER.warning("Exception on update %s", result)
|
LOGGER.warning("Exception on update %s", result)
|
||||||
|
|
||||||
|
async def _update_pollable_api_data(self) -> None:
|
||||||
|
"""Refresh API data for pollable updaters."""
|
||||||
|
await self._refresh_data(self.polling_api_updaters)
|
||||||
|
|
||||||
|
async def _refresh_api_data(self) -> None:
|
||||||
|
"""Refresh API data from network application."""
|
||||||
|
await self._refresh_data(self.api_updaters)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _restore_inactive_clients(self) -> None:
|
def _restore_inactive_clients(self) -> None:
|
||||||
"""Restore inactive clients.
|
"""Restore inactive clients.
|
||||||
|
|
|
@ -20,6 +20,7 @@ from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups
|
||||||
from aiounifi.interfaces.outlets import Outlets
|
from aiounifi.interfaces.outlets import Outlets
|
||||||
from aiounifi.interfaces.port_forwarding import PortForwarding
|
from aiounifi.interfaces.port_forwarding import PortForwarding
|
||||||
from aiounifi.interfaces.ports import Ports
|
from aiounifi.interfaces.ports import Ports
|
||||||
|
from aiounifi.interfaces.traffic_rules import TrafficRules
|
||||||
from aiounifi.interfaces.wlans import Wlans
|
from aiounifi.interfaces.wlans import Wlans
|
||||||
from aiounifi.models.api import ApiItemT
|
from aiounifi.models.api import ApiItemT
|
||||||
from aiounifi.models.client import Client, ClientBlockRequest
|
from aiounifi.models.client import Client, ClientBlockRequest
|
||||||
|
@ -30,6 +31,7 @@ from aiounifi.models.event import Event, EventKey
|
||||||
from aiounifi.models.outlet import Outlet
|
from aiounifi.models.outlet import Outlet
|
||||||
from aiounifi.models.port import Port
|
from aiounifi.models.port import Port
|
||||||
from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest
|
from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest
|
||||||
|
from aiounifi.models.traffic_rule import TrafficRule, TrafficRuleEnableRequest
|
||||||
from aiounifi.models.wlan import Wlan, WlanEnableRequest
|
from aiounifi.models.wlan import Wlan, WlanEnableRequest
|
||||||
|
|
||||||
from homeassistant.components.switch import (
|
from homeassistant.components.switch import (
|
||||||
|
@ -94,8 +96,8 @@ def async_dpi_group_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_port_forward_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
|
def async_unifi_network_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
|
||||||
"""Create device registry entry for port forward."""
|
"""Create device registry entry for the UniFi Network application."""
|
||||||
unique_id = hub.config.entry.unique_id
|
unique_id = hub.config.entry.unique_id
|
||||||
assert unique_id is not None
|
assert unique_id is not None
|
||||||
return DeviceInfo(
|
return DeviceInfo(
|
||||||
|
@ -158,6 +160,16 @@ async def async_port_forward_control_fn(
|
||||||
await hub.api.request(PortForwardEnableRequest.create(port_forward, target))
|
await hub.api.request(PortForwardEnableRequest.create(port_forward, target))
|
||||||
|
|
||||||
|
|
||||||
|
async def async_traffic_rule_control_fn(
|
||||||
|
hub: UnifiHub, obj_id: str, target: bool
|
||||||
|
) -> None:
|
||||||
|
"""Control traffic rule state."""
|
||||||
|
traffic_rule = hub.api.traffic_rules[obj_id].raw
|
||||||
|
await hub.api.request(TrafficRuleEnableRequest.create(traffic_rule, target))
|
||||||
|
# Update the traffic rules so the UI is updated appropriately
|
||||||
|
await hub.api.traffic_rules.update()
|
||||||
|
|
||||||
|
|
||||||
async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None:
|
async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None:
|
||||||
"""Control outlet relay."""
|
"""Control outlet relay."""
|
||||||
await hub.api.request(WlanEnableRequest.create(obj_id, target))
|
await hub.api.request(WlanEnableRequest.create(obj_id, target))
|
||||||
|
@ -232,12 +244,25 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSwitchEntityDescription, ...] = (
|
||||||
icon="mdi:upload-network",
|
icon="mdi:upload-network",
|
||||||
api_handler_fn=lambda api: api.port_forwarding,
|
api_handler_fn=lambda api: api.port_forwarding,
|
||||||
control_fn=async_port_forward_control_fn,
|
control_fn=async_port_forward_control_fn,
|
||||||
device_info_fn=async_port_forward_device_info_fn,
|
device_info_fn=async_unifi_network_device_info_fn,
|
||||||
is_on_fn=lambda hub, port_forward: port_forward.enabled,
|
is_on_fn=lambda hub, port_forward: port_forward.enabled,
|
||||||
name_fn=lambda port_forward: f"{port_forward.name}",
|
name_fn=lambda port_forward: f"{port_forward.name}",
|
||||||
object_fn=lambda api, obj_id: api.port_forwarding[obj_id],
|
object_fn=lambda api, obj_id: api.port_forwarding[obj_id],
|
||||||
unique_id_fn=lambda hub, obj_id: f"port_forward-{obj_id}",
|
unique_id_fn=lambda hub, obj_id: f"port_forward-{obj_id}",
|
||||||
),
|
),
|
||||||
|
UnifiSwitchEntityDescription[TrafficRules, TrafficRule](
|
||||||
|
key="Traffic rule control",
|
||||||
|
device_class=SwitchDeviceClass.SWITCH,
|
||||||
|
entity_category=EntityCategory.CONFIG,
|
||||||
|
icon="mdi:security-network",
|
||||||
|
api_handler_fn=lambda api: api.traffic_rules,
|
||||||
|
control_fn=async_traffic_rule_control_fn,
|
||||||
|
device_info_fn=async_unifi_network_device_info_fn,
|
||||||
|
is_on_fn=lambda hub, traffic_rule: traffic_rule.enabled,
|
||||||
|
name_fn=lambda traffic_rule: traffic_rule.description,
|
||||||
|
object_fn=lambda api, obj_id: api.traffic_rules[obj_id],
|
||||||
|
unique_id_fn=lambda hub, obj_id: f"traffic_rule-{obj_id}",
|
||||||
|
),
|
||||||
UnifiSwitchEntityDescription[Ports, Port](
|
UnifiSwitchEntityDescription[Ports, Port](
|
||||||
key="PoE port control",
|
key="PoE port control",
|
||||||
device_class=SwitchDeviceClass.OUTLET,
|
device_class=SwitchDeviceClass.OUTLET,
|
||||||
|
|
|
@ -160,6 +160,7 @@ def fixture_request(
|
||||||
dpi_app_payload: list[dict[str, Any]],
|
dpi_app_payload: list[dict[str, Any]],
|
||||||
dpi_group_payload: list[dict[str, Any]],
|
dpi_group_payload: list[dict[str, Any]],
|
||||||
port_forward_payload: list[dict[str, Any]],
|
port_forward_payload: list[dict[str, Any]],
|
||||||
|
traffic_rule_payload: list[dict[str, Any]],
|
||||||
site_payload: list[dict[str, Any]],
|
site_payload: list[dict[str, Any]],
|
||||||
system_information_payload: list[dict[str, Any]],
|
system_information_payload: list[dict[str, Any]],
|
||||||
wlan_payload: list[dict[str, Any]],
|
wlan_payload: list[dict[str, Any]],
|
||||||
|
@ -170,9 +171,16 @@ def fixture_request(
|
||||||
url = f"https://{host}:{DEFAULT_PORT}"
|
url = f"https://{host}:{DEFAULT_PORT}"
|
||||||
|
|
||||||
def mock_get_request(path: str, payload: list[dict[str, Any]]) -> None:
|
def mock_get_request(path: str, payload: list[dict[str, Any]]) -> None:
|
||||||
|
# APIV2 request respoonses have `meta` and `data` automatically appended
|
||||||
|
json = {}
|
||||||
|
if path.startswith("/v2"):
|
||||||
|
json = payload
|
||||||
|
else:
|
||||||
|
json = {"meta": {"rc": "OK"}, "data": payload}
|
||||||
|
|
||||||
aioclient_mock.get(
|
aioclient_mock.get(
|
||||||
f"{url}{path}",
|
f"{url}{path}",
|
||||||
json={"meta": {"rc": "OK"}, "data": payload},
|
json=json,
|
||||||
headers={"content-type": CONTENT_TYPE_JSON},
|
headers={"content-type": CONTENT_TYPE_JSON},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -182,6 +190,7 @@ def fixture_request(
|
||||||
json={"data": "login successful", "meta": {"rc": "ok"}},
|
json={"data": "login successful", "meta": {"rc": "ok"}},
|
||||||
headers={"content-type": CONTENT_TYPE_JSON},
|
headers={"content-type": CONTENT_TYPE_JSON},
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_get_request("/api/self/sites", site_payload)
|
mock_get_request("/api/self/sites", site_payload)
|
||||||
mock_get_request(f"/api/s/{site_id}/stat/sta", client_payload)
|
mock_get_request(f"/api/s/{site_id}/stat/sta", client_payload)
|
||||||
mock_get_request(f"/api/s/{site_id}/rest/user", clients_all_payload)
|
mock_get_request(f"/api/s/{site_id}/rest/user", clients_all_payload)
|
||||||
|
@ -191,6 +200,7 @@ def fixture_request(
|
||||||
mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload)
|
mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload)
|
||||||
mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload)
|
mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload)
|
||||||
mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload)
|
mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload)
|
||||||
|
mock_get_request(f"/v2/api/site/{site_id}/trafficrules", traffic_rule_payload)
|
||||||
|
|
||||||
return __mock_requests
|
return __mock_requests
|
||||||
|
|
||||||
|
@ -262,6 +272,12 @@ def fixture_system_information_data() -> list[dict[str, Any]]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="traffic_rule_payload")
|
||||||
|
def traffic_rule_payload_data() -> list[dict[str, Any]]:
|
||||||
|
"""Traffic rule data."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="wlan_payload")
|
@pytest.fixture(name="wlan_payload")
|
||||||
def fixture_wlan_data() -> list[dict[str, Any]]:
|
def fixture_wlan_data() -> list[dict[str, Any]]:
|
||||||
"""WLAN data."""
|
"""WLAN data."""
|
||||||
|
|
|
@ -774,6 +774,37 @@ PORT_FORWARD_PLEX = {
|
||||||
"src": "any",
|
"src": "any",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TRAFFIC_RULE = {
|
||||||
|
"_id": "6452cd9b859d5b11aa002ea1",
|
||||||
|
"action": "BLOCK",
|
||||||
|
"app_category_ids": [],
|
||||||
|
"app_ids": [],
|
||||||
|
"bandwidth_limit": {
|
||||||
|
"download_limit_kbps": 1024,
|
||||||
|
"enabled": False,
|
||||||
|
"upload_limit_kbps": 1024,
|
||||||
|
},
|
||||||
|
"description": "Test Traffic Rule",
|
||||||
|
"name": "Test Traffic Rule",
|
||||||
|
"domains": [],
|
||||||
|
"enabled": True,
|
||||||
|
"ip_addresses": [],
|
||||||
|
"ip_ranges": [],
|
||||||
|
"matching_target": "INTERNET",
|
||||||
|
"network_ids": [],
|
||||||
|
"regions": [],
|
||||||
|
"schedule": {
|
||||||
|
"date_end": "2023-05-10",
|
||||||
|
"date_start": "2023-05-03",
|
||||||
|
"mode": "ALWAYS",
|
||||||
|
"repeat_on_days": [],
|
||||||
|
"time_all_day": False,
|
||||||
|
"time_range_end": "12:00",
|
||||||
|
"time_range_start": "09:00",
|
||||||
|
},
|
||||||
|
"target_devices": [{"client_mac": CLIENT_1["mac"], "type": "CLIENT"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("client_payload", [[CONTROLLER_HOST]])
|
@pytest.mark.parametrize("client_payload", [[CONTROLLER_HOST]])
|
||||||
@pytest.mark.parametrize("device_payload", [[DEVICE_1]])
|
@pytest.mark.parametrize("device_payload", [[DEVICE_1]])
|
||||||
|
@ -1072,6 +1103,64 @@ async def test_dpi_switches_add_second_app(
|
||||||
assert hass.states.get("switch.block_media_streaming").state == STATE_ON
|
assert hass.states.get("switch.block_media_streaming").state == STATE_ON
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("traffic_rule_payload"), [([TRAFFIC_RULE])])
|
||||||
|
@pytest.mark.usefixtures("config_entry_setup")
|
||||||
|
async def test_traffic_rules(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
mock_websocket_message,
|
||||||
|
config_entry_setup: ConfigEntry,
|
||||||
|
traffic_rule_payload: list[dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Test control of UniFi traffic rules."""
|
||||||
|
|
||||||
|
assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1
|
||||||
|
|
||||||
|
# Validate state object
|
||||||
|
switch_1 = hass.states.get("switch.unifi_network_test_traffic_rule")
|
||||||
|
assert switch_1.state == STATE_ON
|
||||||
|
assert switch_1.attributes.get(ATTR_DEVICE_CLASS) == SwitchDeviceClass.SWITCH
|
||||||
|
|
||||||
|
traffic_rule = deepcopy(traffic_rule_payload[0])
|
||||||
|
|
||||||
|
# Disable traffic rule
|
||||||
|
aioclient_mock.put(
|
||||||
|
f"https://{config_entry_setup.data[CONF_HOST]}:1234"
|
||||||
|
f"/v2/api/site/{config_entry_setup.data[CONF_SITE_ID]}/trafficrules/{traffic_rule['_id']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = aioclient_mock.call_count
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
SWITCH_DOMAIN,
|
||||||
|
"turn_off",
|
||||||
|
{"entity_id": "switch.unifi_network_test_traffic_rule"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
# Updating the value for traffic rules will make another call to retrieve the values
|
||||||
|
assert aioclient_mock.call_count == call_count + 2
|
||||||
|
expected_disable_call = deepcopy(traffic_rule)
|
||||||
|
expected_disable_call["enabled"] = False
|
||||||
|
|
||||||
|
assert aioclient_mock.mock_calls[call_count][2] == expected_disable_call
|
||||||
|
|
||||||
|
call_count = aioclient_mock.call_count
|
||||||
|
|
||||||
|
# Enable traffic rule
|
||||||
|
await hass.services.async_call(
|
||||||
|
SWITCH_DOMAIN,
|
||||||
|
"turn_on",
|
||||||
|
{"entity_id": "switch.unifi_network_test_traffic_rule"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_enable_call = deepcopy(traffic_rule)
|
||||||
|
expected_enable_call["enabled"] = True
|
||||||
|
|
||||||
|
assert aioclient_mock.call_count == call_count + 2
|
||||||
|
assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("device_payload", "entity_id", "outlet_index", "expected_switches"),
|
("device_payload", "entity_id", "outlet_index", "expected_switches"),
|
||||||
[
|
[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue