Replace pyunifi with aiounifi in UniFi device tracker (#24149)

* Replace pyunifi with aiounifi

* Fix tests

* Add sslcontext

* Fix tests

* Fix import order
This commit is contained in:
Robert Svensson 2019-06-02 18:24:13 +02:00 committed by GitHub
parent 16a846b1e7
commit 4d4fd19f87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 103 additions and 87 deletions

View file

@ -84,6 +84,7 @@ class UnifiFlowHandler(config_entries.ConfigFlow):
try: try:
desc = user_input.get(CONF_SITE_ID, self.desc) desc = user_input.get(CONF_SITE_ID, self.desc)
print(self.sites)
for site in self.sites.values(): for site in self.sites.values():
if desc == site['desc']: if desc == site['desc']:
if site['role'] != 'admin': if site['role'] != 'admin':

View file

@ -1,5 +1,6 @@
"""UniFi Controller abstraction.""" """UniFi Controller abstraction."""
import asyncio import asyncio
import ssl
import async_timeout import async_timeout
from aiohttp import CookieJar from aiohttp import CookieJar
@ -81,15 +82,19 @@ async def get_controller(
"""Create a controller object and verify authentication.""" """Create a controller object and verify authentication."""
import aiounifi import aiounifi
sslcontext = None
if verify_ssl: if verify_ssl:
session = aiohttp_client.async_get_clientsession(hass) session = aiohttp_client.async_get_clientsession(hass)
if isinstance(verify_ssl, str):
sslcontext = ssl.create_default_context(cafile=verify_ssl)
else: else:
session = aiohttp_client.async_create_clientsession( session = aiohttp_client.async_create_clientsession(
hass, verify_ssl=verify_ssl, cookie_jar=CookieJar(unsafe=True)) hass, verify_ssl=verify_ssl, cookie_jar=CookieJar(unsafe=True))
controller = aiounifi.Controller( controller = aiounifi.Controller(
host, username=username, password=password, port=port, site=site, host, username=username, password=password, port=port, site=site,
websession=session websession=session, sslcontext=sslcontext
) )
try: try:

View file

@ -1,8 +1,13 @@
"""Support for Unifi WAP controllers.""" """Support for Unifi WAP controllers."""
import asyncio
import logging import logging
from datetime import timedelta from datetime import timedelta
import voluptuous as vol import voluptuous as vol
import async_timeout
import aiounifi
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import ( from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner) DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
@ -10,6 +15,9 @@ from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD
from homeassistant.const import CONF_VERIFY_SSL, CONF_MONITORED_CONDITIONS from homeassistant.const import CONF_VERIFY_SSL, CONF_MONITORED_CONDITIONS
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .controller import get_controller
from .errors import AuthenticationRequired, CannotConnect
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_PORT = 'port' CONF_PORT = 'port'
CONF_SITE_ID = 'site_id' CONF_SITE_ID = 'site_id'
@ -54,10 +62,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
}) })
def get_scanner(hass, config): async def async_get_scanner(hass, config):
"""Set up the Unifi device_tracker.""" """Set up the Unifi device_tracker."""
from pyunifi.controller import Controller, APIError
host = config[DOMAIN].get(CONF_HOST) host = config[DOMAIN].get(CONF_HOST)
username = config[DOMAIN].get(CONF_USERNAME) username = config[DOMAIN].get(CONF_USERNAME)
password = config[DOMAIN].get(CONF_PASSWORD) password = config[DOMAIN].get(CONF_PASSWORD)
@ -69,9 +75,11 @@ def get_scanner(hass, config):
ssid_filter = config[DOMAIN].get(CONF_SSID_FILTER) ssid_filter = config[DOMAIN].get(CONF_SSID_FILTER)
try: try:
ctrl = Controller(host, username, password, port, version='v4', controller = await get_controller(
site_id=site_id, ssl_verify=verify_ssl) hass, host, username, password, port, site_id, verify_ssl)
except APIError as ex: await controller.initialize()
except (AuthenticationRequired, CannotConnect) as ex:
_LOGGER.error("Failed to connect to Unifi: %s", ex) _LOGGER.error("Failed to connect to Unifi: %s", ex)
hass.components.persistent_notification.create( hass.components.persistent_notification.create(
'Failed to connect to Unifi. ' 'Failed to connect to Unifi. '
@ -82,8 +90,8 @@ def get_scanner(hass, config):
notification_id=NOTIFICATION_ID) notification_id=NOTIFICATION_ID)
return False return False
return UnifiScanner(ctrl, detection_time, ssid_filter, return UnifiScanner(
monitored_conditions) controller, detection_time, ssid_filter, monitored_conditions)
class UnifiScanner(DeviceScanner): class UnifiScanner(DeviceScanner):
@ -92,36 +100,45 @@ class UnifiScanner(DeviceScanner):
def __init__(self, controller, detection_time: timedelta, def __init__(self, controller, detection_time: timedelta,
ssid_filter, monitored_conditions) -> None: ssid_filter, monitored_conditions) -> None:
"""Initialize the scanner.""" """Initialize the scanner."""
self.controller = controller
self._detection_time = detection_time self._detection_time = detection_time
self._controller = controller
self._ssid_filter = ssid_filter self._ssid_filter = ssid_filter
self._monitored_conditions = monitored_conditions self._monitored_conditions = monitored_conditions
self._update() self._clients = {}
def _update(self): async def async_update(self):
"""Get the clients from the device.""" """Get the clients from the device."""
from pyunifi.controller import APIError
try: try:
clients = self._controller.get_clients() await self.controller.clients.update()
except APIError as ex: clients = self.controller.clients.values()
_LOGGER.error("Failed to scan clients: %s", ex)
except aiounifi.LoginRequired:
try:
with async_timeout.timeout(5):
await self.controller.login()
except (asyncio.TimeoutError, aiounifi.AiounifiException):
clients = []
except aiounifi.AiounifiException:
clients = [] clients = []
# Filter clients to provided SSID list # Filter clients to provided SSID list
if self._ssid_filter: if self._ssid_filter:
clients = [client for client in clients clients = [
if 'essid' in client and client for client in clients
client['essid'] in self._ssid_filter] if client.essid in self._ssid_filter
]
self._clients = { self._clients = {
client['mac']: client client.raw['mac']: client.raw
for client in clients for client in clients
if (dt_util.utcnow() - dt_util.utc_from_timestamp(float( if (dt_util.utcnow() - dt_util.utc_from_timestamp(float(
client['last_seen']))) < self._detection_time} client.last_seen))) < self._detection_time
}
def scan_devices(self): async def async_scan_devices(self):
"""Scan for devices.""" """Scan for devices."""
self._update() await self.async_update()
return self._clients.keys() return self._clients.keys()
def get_device_name(self, device): def get_device_name(self, device):

View file

@ -4,8 +4,7 @@
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/components/unifi", "documentation": "https://www.home-assistant.io/components/unifi",
"requirements": [ "requirements": [
"aiounifi==4", "aiounifi==6"
"pyunifi==2.16"
], ],
"dependencies": [], "dependencies": [],
"codeowners": [ "codeowners": [

View file

@ -163,7 +163,7 @@ aiopvapi==1.6.14
aioswitcher==2019.3.21 aioswitcher==2019.3.21
# homeassistant.components.unifi # homeassistant.components.unifi
aiounifi==4 aiounifi==6
# homeassistant.components.aladdin_connect # homeassistant.components.aladdin_connect
aladdin_connect==0.3 aladdin_connect==0.3
@ -1488,9 +1488,6 @@ pytrafikverket==0.1.5.9
# homeassistant.components.ubee # homeassistant.components.ubee
pyubee==0.6 pyubee==0.6
# homeassistant.components.unifi
pyunifi==2.16
# homeassistant.components.uptimerobot # homeassistant.components.uptimerobot
pyuptimerobot==0.0.5 pyuptimerobot==0.0.5

View file

@ -61,7 +61,7 @@ aiohue==1.9.1
aioswitcher==2019.3.21 aioswitcher==2019.3.21
# homeassistant.components.unifi # homeassistant.components.unifi
aiounifi==4 aiounifi==6
# homeassistant.components.ambiclimate # homeassistant.components.ambiclimate
ambiclimate==0.1.2 ambiclimate==0.1.2
@ -294,9 +294,6 @@ python_awair==0.0.4
# homeassistant.components.tradfri # homeassistant.components.tradfri
pytradfri[async]==6.0.1 pytradfri[async]==6.0.1
# homeassistant.components.unifi
pyunifi==2.16
# homeassistant.components.html5 # homeassistant.components.html5
pywebpush==1.9.2 pywebpush==1.9.2

View file

@ -1,8 +1,6 @@
"""The tests for the Unifi WAP device tracker platform.""" """The tests for the Unifi WAP device tracker platform."""
from unittest import mock from unittest import mock
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pyunifi.controller import APIError
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -13,13 +11,20 @@ import homeassistant.components.unifi.device_tracker as unifi
from homeassistant.const import (CONF_HOST, CONF_USERNAME, CONF_PASSWORD, from homeassistant.const import (CONF_HOST, CONF_USERNAME, CONF_PASSWORD,
CONF_PLATFORM, CONF_VERIFY_SSL, CONF_PLATFORM, CONF_VERIFY_SSL,
CONF_MONITORED_CONDITIONS) CONF_MONITORED_CONDITIONS)
from tests.common import mock_coro
from asynctest import CoroutineMock
from aiounifi.clients import Clients
DEFAULT_DETECTION_TIME = timedelta(seconds=300) DEFAULT_DETECTION_TIME = timedelta(seconds=300)
@pytest.fixture @pytest.fixture
def mock_ctrl(): def mock_ctrl():
"""Mock pyunifi.""" """Mock pyunifi."""
with mock.patch('pyunifi.controller.Controller') as mock_control: with mock.patch('aiounifi.Controller') as mock_control:
mock_control.return_value.login.return_value = mock_coro()
mock_control.return_value.initialize.return_value = mock_coro()
yield mock_control yield mock_control
@ -33,7 +38,7 @@ def mock_scanner():
@mock.patch('os.access', return_value=True) @mock.patch('os.access', return_value=True)
@mock.patch('os.path.isfile', mock.Mock(return_value=True)) @mock.patch('os.path.isfile', mock.Mock(return_value=True))
def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): async def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl):
"""Test the setup with a string for ssl_verify. """Test the setup with a string for ssl_verify.
Representing the absolute path to a CA certificate bundle. Representing the absolute path to a CA certificate bundle.
@ -46,12 +51,9 @@ def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl):
CONF_VERIFY_SSL: "/tmp/unifi.crt" CONF_VERIFY_SSL: "/tmp/unifi.crt"
}) })
} }
result = unifi.get_scanner(hass, config) result = await unifi.async_get_scanner(hass, config)
assert mock_scanner.return_value == result assert mock_scanner.return_value == result
assert mock_ctrl.call_count == 1 assert mock_ctrl.call_count == 1
assert mock_ctrl.mock_calls[0] == \
mock.call('localhost', 'foo', 'password', 8443,
version='v4', site_id='default', ssl_verify="/tmp/unifi.crt")
assert mock_scanner.call_count == 1 assert mock_scanner.call_count == 1
assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, assert mock_scanner.call_args == mock.call(mock_ctrl.return_value,
@ -59,7 +61,7 @@ def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl):
None, None) None, None)
def test_config_minimal(hass, mock_scanner, mock_ctrl): async def test_config_minimal(hass, mock_scanner, mock_ctrl):
"""Test the setup with minimal configuration.""" """Test the setup with minimal configuration."""
config = { config = {
DOMAIN: unifi.PLATFORM_SCHEMA({ DOMAIN: unifi.PLATFORM_SCHEMA({
@ -68,12 +70,10 @@ def test_config_minimal(hass, mock_scanner, mock_ctrl):
CONF_PASSWORD: 'password', CONF_PASSWORD: 'password',
}) })
} }
result = unifi.get_scanner(hass, config)
result = await unifi.async_get_scanner(hass, config)
assert mock_scanner.return_value == result assert mock_scanner.return_value == result
assert mock_ctrl.call_count == 1 assert mock_ctrl.call_count == 1
assert mock_ctrl.mock_calls[0] == \
mock.call('localhost', 'foo', 'password', 8443,
version='v4', site_id='default', ssl_verify=True)
assert mock_scanner.call_count == 1 assert mock_scanner.call_count == 1
assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, assert mock_scanner.call_args == mock.call(mock_ctrl.return_value,
@ -81,7 +81,7 @@ def test_config_minimal(hass, mock_scanner, mock_ctrl):
None, None) None, None)
def test_config_full(hass, mock_scanner, mock_ctrl): async def test_config_full(hass, mock_scanner, mock_ctrl):
"""Test the setup with full configuration.""" """Test the setup with full configuration."""
config = { config = {
DOMAIN: unifi.PLATFORM_SCHEMA({ DOMAIN: unifi.PLATFORM_SCHEMA({
@ -96,12 +96,9 @@ def test_config_full(hass, mock_scanner, mock_ctrl):
'detection_time': 300, 'detection_time': 300,
}) })
} }
result = unifi.get_scanner(hass, config) result = await unifi.async_get_scanner(hass, config)
assert mock_scanner.return_value == result assert mock_scanner.return_value == result
assert mock_ctrl.call_count == 1 assert mock_ctrl.call_count == 1
assert mock_ctrl.call_args == \
mock.call('myhost', 'foo', 'password', 123,
version='v4', site_id='abcdef01', ssl_verify=False)
assert mock_scanner.call_count == 1 assert mock_scanner.call_count == 1
assert mock_scanner.call_args == mock.call( assert mock_scanner.call_args == mock.call(
@ -137,7 +134,7 @@ def test_config_error():
}) })
def test_config_controller_failed(hass, mock_ctrl, mock_scanner): async def test_config_controller_failed(hass, mock_ctrl, mock_scanner):
"""Test for controller failure.""" """Test for controller failure."""
config = { config = {
'device_tracker': { 'device_tracker': {
@ -146,13 +143,12 @@ def test_config_controller_failed(hass, mock_ctrl, mock_scanner):
CONF_PASSWORD: 'password', CONF_PASSWORD: 'password',
} }
} }
mock_ctrl.side_effect = APIError( mock_ctrl.side_effect = unifi.CannotConnect
'/', 500, 'foo', {}, None) result = await unifi.async_get_scanner(hass, config)
result = unifi.get_scanner(hass, config)
assert result is False assert result is False
def test_scanner_update(): async def test_scanner_update():
"""Test the scanner update.""" """Test the scanner update."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
fake_clients = [ fake_clients = [
@ -161,21 +157,20 @@ def test_scanner_update():
{'mac': '234', 'essid': 'barnet', {'mac': '234', 'essid': 'barnet',
'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, 'last_seen': dt_util.as_timestamp(dt_util.utcnow())},
] ]
ctrl.get_clients.return_value = fake_clients ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients))
unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None)
assert ctrl.get_clients.call_count == 1 await scnr.async_update()
assert ctrl.get_clients.call_args == mock.call() assert len(scnr._clients) == 2
def test_scanner_update_error(): def test_scanner_update_error():
"""Test the scanner update for error.""" """Test the scanner update for error."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
ctrl.get_clients.side_effect = APIError( ctrl.get_clients.side_effect = unifi.aiounifi.AiounifiException
'/', 500, 'foo', {}, None)
unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None)
def test_scan_devices(): async def test_scan_devices():
"""Test the scanning for devices.""" """Test the scanning for devices."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
fake_clients = [ fake_clients = [
@ -184,12 +179,13 @@ def test_scan_devices():
{'mac': '234', 'essid': 'barnet', {'mac': '234', 'essid': 'barnet',
'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, 'last_seen': dt_util.as_timestamp(dt_util.utcnow())},
] ]
ctrl.get_clients.return_value = fake_clients ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients))
scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None)
assert set(scanner.scan_devices()) == set(['123', '234']) await scnr.async_update()
assert set(await scnr.async_scan_devices()) == set(['123', '234'])
def test_scan_devices_filtered(): async def test_scan_devices_filtered():
"""Test the scanning for devices based on SSID.""" """Test the scanning for devices based on SSID."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
fake_clients = [ fake_clients = [
@ -204,13 +200,13 @@ def test_scan_devices_filtered():
] ]
ssid_filter = ['foonet', 'barnet'] ssid_filter = ['foonet', 'barnet']
ctrl.get_clients.return_value = fake_clients ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients))
scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, ssid_filter, scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, ssid_filter, None)
None) await scnr.async_update()
assert set(scanner.scan_devices()) == set(['123', '234', '890']) assert set(await scnr.async_scan_devices()) == set(['123', '234', '890'])
def test_get_device_name(): async def test_get_device_name():
"""Test the getting of device names.""" """Test the getting of device names."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
fake_clients = [ fake_clients = [
@ -226,15 +222,16 @@ def test_get_device_name():
'essid': 'barnet', 'essid': 'barnet',
'last_seen': '1504786810'}, 'last_seen': '1504786810'},
] ]
ctrl.get_clients.return_value = fake_clients ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients))
scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None)
assert scanner.get_device_name('123') == 'foobar' await scnr.async_update()
assert scanner.get_device_name('234') == 'Nice Name' assert scnr.get_device_name('123') == 'foobar'
assert scanner.get_device_name('456') is None assert scnr.get_device_name('234') == 'Nice Name'
assert scanner.get_device_name('unknown') is None assert scnr.get_device_name('456') is None
assert scnr.get_device_name('unknown') is None
def test_monitored_conditions(): async def test_monitored_conditions():
"""Test the filtering of attributes.""" """Test the filtering of attributes."""
ctrl = mock.MagicMock() ctrl = mock.MagicMock()
fake_clients = [ fake_clients = [
@ -254,16 +251,17 @@ def test_monitored_conditions():
'essid': 'barnet', 'essid': 'barnet',
'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, 'last_seen': dt_util.as_timestamp(dt_util.utcnow())},
] ]
ctrl.get_clients.return_value = fake_clients ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients))
scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None,
['essid', 'signal', 'latest_assoc_time']) ['essid', 'signal', 'latest_assoc_time'])
assert scanner.get_extra_attributes('123') == { await scnr.async_update()
assert scnr.get_extra_attributes('123') == {
'essid': 'barnet', 'essid': 'barnet',
'signal': -60, 'signal': -60,
'latest_assoc_time': datetime(2000, 1, 1, 0, 0, tzinfo=dt_util.UTC) 'latest_assoc_time': datetime(2000, 1, 1, 0, 0, tzinfo=dt_util.UTC)
} }
assert scanner.get_extra_attributes('234') == { assert scnr.get_extra_attributes('234') == {
'essid': 'barnet', 'essid': 'barnet',
'signal': -42 'signal': -42
} }
assert scanner.get_extra_attributes('456') == {'essid': 'barnet'} assert scnr.get_extra_attributes('456') == {'essid': 'barnet'}

View file

@ -146,7 +146,8 @@ async def test_flow_works(hass, aioclient_mock):
flow.hass = hass flow.hass = hass
with patch('aiounifi.Controller') as mock_controller: with patch('aiounifi.Controller') as mock_controller:
def mock_constructor(host, username, password, port, site, websession): def mock_constructor(
host, username, password, port, site, websession, sslcontext):
"""Fake the controller constructor.""" """Fake the controller constructor."""
mock_controller.host = host mock_controller.host = host
mock_controller.username = username mock_controller.username = username
@ -254,7 +255,8 @@ async def test_user_permissions_low(hass, aioclient_mock):
flow.hass = hass flow.hass = hass
with patch('aiounifi.Controller') as mock_controller: with patch('aiounifi.Controller') as mock_controller:
def mock_constructor(host, username, password, port, site, websession): def mock_constructor(
host, username, password, port, site, websession, sslcontext):
"""Fake the controller constructor.""" """Fake the controller constructor."""
mock_controller.host = host mock_controller.host = host
mock_controller.username = username mock_controller.username = username