diff --git a/homeassistant/components/unifi/hub/entity_loader.py b/homeassistant/components/unifi/hub/entity_loader.py index 29448a4114a..f11ddefec98 100644 --- a/homeassistant/components/unifi/hub/entity_loader.py +++ b/homeassistant/components/unifi/hub/entity_loader.py @@ -7,9 +7,10 @@ Make sure expected clients are available for platforms. from __future__ import annotations import asyncio +from collections.abc import Callable, Coroutine, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from aiounifi.interfaces.api_handlers import ItemEvent @@ -18,6 +19,7 @@ from homeassistant.core import callback from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS from ..entity import UnifiEntity, UnifiEntityDescription @@ -26,6 +28,7 @@ if TYPE_CHECKING: from .hub import UnifiHub CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1) +POLL_INTERVAL = timedelta(seconds=10) class UnifiEntityLoader: @@ -43,10 +46,24 @@ class UnifiEntityLoader: hub.api.port_forwarding.update, hub.api.sites.update, hub.api.system_information.update, + hub.api.traffic_rules.update, hub.api.wlans.update, ) + self.polling_api_updaters = (hub.api.traffic_rules.update,) self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS] + self._dataUpdateCoordinator = DataUpdateCoordinator( + hub.hass, + LOGGER, + name="Unifi entity poller", + update_method=self._update_pollable_api_data, + update_interval=POLL_INTERVAL, + ) + + self._update_listener = self._dataUpdateCoordinator.async_add_listener( + update_callback=lambda: None + ) + self.platforms: list[ tuple[ AddEntitiesCallback, @@ -65,16 +82,25 @@ class UnifiEntityLoader: self._restore_inactive_clients() self.wireless_clients.update_clients(set(self.hub.api.clients.values())) - async def _refresh_api_data(self) -> None: - """Refresh API data from network application.""" + async def _refresh_data( + self, updaters: Sequence[Callable[[], Coroutine[Any, Any, None]]] + ) -> None: results = await asyncio.gather( - *[update() for update in self.api_updaters], + *[update() for update in updaters], return_exceptions=True, ) for result in results: if result is not None: LOGGER.warning("Exception on update %s", result) + async def _update_pollable_api_data(self) -> None: + """Refresh API data for pollable updaters.""" + await self._refresh_data(self.polling_api_updaters) + + async def _refresh_api_data(self) -> None: + """Refresh API data from network application.""" + await self._refresh_data(self.api_updaters) + @callback def _restore_inactive_clients(self) -> None: """Restore inactive clients. diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index ef30abb9349..93a0c81a24e 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -20,6 +20,7 @@ from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups from aiounifi.interfaces.outlets import Outlets from aiounifi.interfaces.port_forwarding import PortForwarding from aiounifi.interfaces.ports import Ports +from aiounifi.interfaces.traffic_rules import TrafficRules from aiounifi.interfaces.wlans import Wlans from aiounifi.models.api import ApiItemT from aiounifi.models.client import Client, ClientBlockRequest @@ -30,6 +31,7 @@ from aiounifi.models.event import Event, EventKey from aiounifi.models.outlet import Outlet from aiounifi.models.port import Port from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest +from aiounifi.models.traffic_rule import TrafficRule, TrafficRuleEnableRequest from aiounifi.models.wlan import Wlan, WlanEnableRequest from homeassistant.components.switch import ( @@ -94,8 +96,8 @@ def async_dpi_group_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo: @callback -def async_port_forward_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo: - """Create device registry entry for port forward.""" +def async_unifi_network_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo: + """Create device registry entry for the UniFi Network application.""" unique_id = hub.config.entry.unique_id assert unique_id is not None return DeviceInfo( @@ -158,6 +160,16 @@ async def async_port_forward_control_fn( await hub.api.request(PortForwardEnableRequest.create(port_forward, target)) +async def async_traffic_rule_control_fn( + hub: UnifiHub, obj_id: str, target: bool +) -> None: + """Control traffic rule state.""" + traffic_rule = hub.api.traffic_rules[obj_id].raw + await hub.api.request(TrafficRuleEnableRequest.create(traffic_rule, target)) + # Update the traffic rules so the UI is updated appropriately + await hub.api.traffic_rules.update() + + async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None: """Control outlet relay.""" await hub.api.request(WlanEnableRequest.create(obj_id, target)) @@ -232,12 +244,25 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSwitchEntityDescription, ...] = ( icon="mdi:upload-network", api_handler_fn=lambda api: api.port_forwarding, control_fn=async_port_forward_control_fn, - device_info_fn=async_port_forward_device_info_fn, + device_info_fn=async_unifi_network_device_info_fn, is_on_fn=lambda hub, port_forward: port_forward.enabled, name_fn=lambda port_forward: f"{port_forward.name}", object_fn=lambda api, obj_id: api.port_forwarding[obj_id], unique_id_fn=lambda hub, obj_id: f"port_forward-{obj_id}", ), + UnifiSwitchEntityDescription[TrafficRules, TrafficRule]( + key="Traffic rule control", + device_class=SwitchDeviceClass.SWITCH, + entity_category=EntityCategory.CONFIG, + icon="mdi:security-network", + api_handler_fn=lambda api: api.traffic_rules, + control_fn=async_traffic_rule_control_fn, + device_info_fn=async_unifi_network_device_info_fn, + is_on_fn=lambda hub, traffic_rule: traffic_rule.enabled, + name_fn=lambda traffic_rule: traffic_rule.description, + object_fn=lambda api, obj_id: api.traffic_rules[obj_id], + unique_id_fn=lambda hub, obj_id: f"traffic_rule-{obj_id}", + ), UnifiSwitchEntityDescription[Ports, Port]( key="PoE port control", device_class=SwitchDeviceClass.OUTLET, diff --git a/tests/components/unifi/conftest.py b/tests/components/unifi/conftest.py index c20b8766bfc..4e460bab8f8 100644 --- a/tests/components/unifi/conftest.py +++ b/tests/components/unifi/conftest.py @@ -160,6 +160,7 @@ def fixture_request( dpi_app_payload: list[dict[str, Any]], dpi_group_payload: list[dict[str, Any]], port_forward_payload: list[dict[str, Any]], + traffic_rule_payload: list[dict[str, Any]], site_payload: list[dict[str, Any]], system_information_payload: list[dict[str, Any]], wlan_payload: list[dict[str, Any]], @@ -170,9 +171,16 @@ def fixture_request( url = f"https://{host}:{DEFAULT_PORT}" def mock_get_request(path: str, payload: list[dict[str, Any]]) -> None: + # APIV2 request respoonses have `meta` and `data` automatically appended + json = {} + if path.startswith("/v2"): + json = payload + else: + json = {"meta": {"rc": "OK"}, "data": payload} + aioclient_mock.get( f"{url}{path}", - json={"meta": {"rc": "OK"}, "data": payload}, + json=json, headers={"content-type": CONTENT_TYPE_JSON}, ) @@ -182,6 +190,7 @@ def fixture_request( json={"data": "login successful", "meta": {"rc": "ok"}}, headers={"content-type": CONTENT_TYPE_JSON}, ) + mock_get_request("/api/self/sites", site_payload) mock_get_request(f"/api/s/{site_id}/stat/sta", client_payload) mock_get_request(f"/api/s/{site_id}/rest/user", clients_all_payload) @@ -191,6 +200,7 @@ def fixture_request( mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload) mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload) mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload) + mock_get_request(f"/v2/api/site/{site_id}/trafficrules", traffic_rule_payload) return __mock_requests @@ -262,6 +272,12 @@ def fixture_system_information_data() -> list[dict[str, Any]]: ] +@pytest.fixture(name="traffic_rule_payload") +def traffic_rule_payload_data() -> list[dict[str, Any]]: + """Traffic rule data.""" + return [] + + @pytest.fixture(name="wlan_payload") def fixture_wlan_data() -> list[dict[str, Any]]: """WLAN data.""" diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index b0ae8bde445..daf64301c8e 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -774,6 +774,37 @@ PORT_FORWARD_PLEX = { "src": "any", } +TRAFFIC_RULE = { + "_id": "6452cd9b859d5b11aa002ea1", + "action": "BLOCK", + "app_category_ids": [], + "app_ids": [], + "bandwidth_limit": { + "download_limit_kbps": 1024, + "enabled": False, + "upload_limit_kbps": 1024, + }, + "description": "Test Traffic Rule", + "name": "Test Traffic Rule", + "domains": [], + "enabled": True, + "ip_addresses": [], + "ip_ranges": [], + "matching_target": "INTERNET", + "network_ids": [], + "regions": [], + "schedule": { + "date_end": "2023-05-10", + "date_start": "2023-05-03", + "mode": "ALWAYS", + "repeat_on_days": [], + "time_all_day": False, + "time_range_end": "12:00", + "time_range_start": "09:00", + }, + "target_devices": [{"client_mac": CLIENT_1["mac"], "type": "CLIENT"}], +} + @pytest.mark.parametrize("client_payload", [[CONTROLLER_HOST]]) @pytest.mark.parametrize("device_payload", [[DEVICE_1]]) @@ -1072,6 +1103,64 @@ async def test_dpi_switches_add_second_app( assert hass.states.get("switch.block_media_streaming").state == STATE_ON +@pytest.mark.parametrize(("traffic_rule_payload"), [([TRAFFIC_RULE])]) +@pytest.mark.usefixtures("config_entry_setup") +async def test_traffic_rules( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + mock_websocket_message, + config_entry_setup: ConfigEntry, + traffic_rule_payload: list[dict[str, Any]], +) -> None: + """Test control of UniFi traffic rules.""" + + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 + + # Validate state object + switch_1 = hass.states.get("switch.unifi_network_test_traffic_rule") + assert switch_1.state == STATE_ON + assert switch_1.attributes.get(ATTR_DEVICE_CLASS) == SwitchDeviceClass.SWITCH + + traffic_rule = deepcopy(traffic_rule_payload[0]) + + # Disable traffic rule + aioclient_mock.put( + f"https://{config_entry_setup.data[CONF_HOST]}:1234" + f"/v2/api/site/{config_entry_setup.data[CONF_SITE_ID]}/trafficrules/{traffic_rule['_id']}", + ) + + call_count = aioclient_mock.call_count + + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_off", + {"entity_id": "switch.unifi_network_test_traffic_rule"}, + blocking=True, + ) + # Updating the value for traffic rules will make another call to retrieve the values + assert aioclient_mock.call_count == call_count + 2 + expected_disable_call = deepcopy(traffic_rule) + expected_disable_call["enabled"] = False + + assert aioclient_mock.mock_calls[call_count][2] == expected_disable_call + + call_count = aioclient_mock.call_count + + # Enable traffic rule + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_on", + {"entity_id": "switch.unifi_network_test_traffic_rule"}, + blocking=True, + ) + + expected_enable_call = deepcopy(traffic_rule) + expected_enable_call["enabled"] = True + + assert aioclient_mock.call_count == call_count + 2 + assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call + + @pytest.mark.parametrize( ("device_payload", "entity_id", "outlet_index", "expected_switches"), [