Expose async_scanner_devices_by_address from the bluetooth api (#83733)

Co-authored-by: J. Nick Koston <nick@koston.org>
fixes undefined
This commit is contained in:
David Buezas 2023-01-09 01:06:32 +01:00 committed by GitHub
parent 06a35fb7db
commit 112b2c22f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 179 additions and 47 deletions

View file

@ -58,9 +58,10 @@ from .api import (
async_register_scanner, async_register_scanner,
async_scanner_by_source, async_scanner_by_source,
async_scanner_count, async_scanner_count,
async_scanner_devices_by_address,
async_track_unavailable, async_track_unavailable,
) )
from .base_scanner import BaseHaRemoteScanner, BaseHaScanner from .base_scanner import BaseHaRemoteScanner, BaseHaScanner, BluetoothScannerDevice
from .const import ( from .const import (
BLUETOOTH_DISCOVERY_COOLDOWN_SECONDS, BLUETOOTH_DISCOVERY_COOLDOWN_SECONDS,
CONF_ADAPTER, CONF_ADAPTER,
@ -99,6 +100,7 @@ __all__ = [
"async_track_unavailable", "async_track_unavailable",
"async_scanner_by_source", "async_scanner_by_source",
"async_scanner_count", "async_scanner_count",
"async_scanner_devices_by_address",
"BaseHaScanner", "BaseHaScanner",
"BaseHaRemoteScanner", "BaseHaRemoteScanner",
"BluetoothCallbackMatcher", "BluetoothCallbackMatcher",
@ -107,6 +109,7 @@ __all__ = [
"BluetoothServiceInfoBleak", "BluetoothServiceInfoBleak",
"BluetoothScanningMode", "BluetoothScanningMode",
"BluetoothCallback", "BluetoothCallback",
"BluetoothScannerDevice",
"HaBluetoothConnector", "HaBluetoothConnector",
"SOURCE_LOCAL", "SOURCE_LOCAL",
"FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS", "FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS",

View file

@ -13,7 +13,7 @@ from home_assistant_bluetooth import BluetoothServiceInfoBleak
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback
from .base_scanner import BaseHaScanner from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import DATA_MANAGER from .const import DATA_MANAGER
from .manager import BluetoothManager from .manager import BluetoothManager
from .match import BluetoothCallbackMatcher from .match import BluetoothCallbackMatcher
@ -93,6 +93,14 @@ def async_ble_device_from_address(
return _get_manager(hass).async_ble_device_from_address(address, connectable) return _get_manager(hass).async_ble_device_from_address(address, connectable)
@hass_callback
def async_scanner_devices_by_address(
hass: HomeAssistant, address: str, connectable: bool = True
) -> list[BluetoothScannerDevice]:
"""Return all discovered BluetoothScannerDevice for an address."""
return _get_manager(hass).async_scanner_devices_by_address(address, connectable)
@hass_callback @hass_callback
def async_address_present( def async_address_present(
hass: HomeAssistant, address: str, connectable: bool = True hass: HomeAssistant, address: str, connectable: bool = True

View file

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
import datetime import datetime
from datetime import timedelta from datetime import timedelta
import logging import logging
@ -39,6 +40,15 @@ MONOTONIC_TIME: Final = monotonic_time_coarse
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@dataclass
class BluetoothScannerDevice:
"""Data for a bluetooth device from a given scanner."""
scanner: BaseHaScanner
ble_device: BLEDevice
advertisement: AdvertisementData
class BaseHaScanner(ABC): class BaseHaScanner(ABC):
"""Base class for Ha Scanners.""" """Base class for Ha Scanners."""

View file

@ -29,7 +29,7 @@ from homeassistant.helpers.event import async_track_time_interval
from homeassistant.util.dt import monotonic_time_coarse from homeassistant.util.dt import monotonic_time_coarse
from .advertisement_tracker import AdvertisementTracker from .advertisement_tracker import AdvertisementTracker
from .base_scanner import BaseHaScanner from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import ( from .const import (
FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS, FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS,
UNAVAILABLE_TRACK_SECONDS, UNAVAILABLE_TRACK_SECONDS,
@ -217,18 +217,22 @@ class BluetoothManager:
uninstall_multiple_bleak_catcher() uninstall_multiple_bleak_catcher()
@hass_callback @hass_callback
def async_get_scanner_discovered_devices_and_advertisement_data_by_address( def async_scanner_devices_by_address(
self, address: str, connectable: bool self, address: str, connectable: bool
) -> list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]]: ) -> list[BluetoothScannerDevice]:
"""Get scanner, devices, and advertisement_data by address.""" """Get BluetoothScannerDevice by address."""
types_ = (True,) if connectable else (True, False) scanners = self._get_scanners_by_type(True)
results: list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]] = [] if not connectable:
for type_ in types_: scanners.extend(self._get_scanners_by_type(False))
for scanner in self._get_scanners_by_type(type_): return [
devices_and_adv_data = scanner.discovered_devices_and_advertisement_data BluetoothScannerDevice(scanner, *device_adv)
if device_adv_data := devices_and_adv_data.get(address): for scanner in scanners
results.append((scanner, *device_adv_data)) if (
return results device_adv := scanner.discovered_devices_and_advertisement_data.get(
address
)
)
]
@hass_callback @hass_callback
def _async_all_discovered_addresses(self, connectable: bool) -> Iterable[str]: def _async_all_discovered_addresses(self, connectable: bool) -> Iterable[str]:

View file

@ -12,11 +12,7 @@ from typing import TYPE_CHECKING, Any, Final
from bleak import BleakClient, BleakError from bleak import BleakClient, BleakError
from bleak.backends.client import BaseBleakClient, get_platform_client_backend_type from bleak.backends.client import BaseBleakClient, get_platform_client_backend_type
from bleak.backends.device import BLEDevice from bleak.backends.device import BLEDevice
from bleak.backends.scanner import ( from bleak.backends.scanner import AdvertisementDataCallback, BaseBleakScanner
AdvertisementData,
AdvertisementDataCallback,
BaseBleakScanner,
)
from bleak_retry_connector import ( from bleak_retry_connector import (
NO_RSSI_VALUE, NO_RSSI_VALUE,
ble_device_description, ble_device_description,
@ -28,7 +24,7 @@ from homeassistant.core import CALLBACK_TYPE, callback as hass_callback
from homeassistant.helpers.frame import report from homeassistant.helpers.frame import report
from . import models from . import models
from .base_scanner import BaseHaScanner from .base_scanner import BaseHaScanner, BluetoothScannerDevice
FILTER_UUIDS: Final = "UUIDs" FILTER_UUIDS: Final = "UUIDs"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -149,9 +145,7 @@ class HaBleakScannerWrapper(BaseBleakScanner):
def _rssi_sorter_with_connection_failure_penalty( def _rssi_sorter_with_connection_failure_penalty(
scanner_device_advertisement_data: tuple[ device: BluetoothScannerDevice,
BaseHaScanner, BLEDevice, AdvertisementData
],
connection_failure_count: dict[BaseHaScanner, int], connection_failure_count: dict[BaseHaScanner, int],
rssi_diff: int, rssi_diff: int,
) -> float: ) -> float:
@ -168,9 +162,8 @@ def _rssi_sorter_with_connection_failure_penalty(
best adapter twice before moving on to the next best adapter since best adapter twice before moving on to the next best adapter since
the first failure may be a transient service resolution issue. the first failure may be a transient service resolution issue.
""" """
scanner, _, advertisement_data = scanner_device_advertisement_data base_rssi = device.advertisement.rssi or NO_RSSI_VALUE
base_rssi = advertisement_data.rssi or NO_RSSI_VALUE if connect_failures := connection_failure_count.get(device.scanner):
if connect_failures := connection_failure_count.get(scanner):
if connect_failures > 1 and not rssi_diff: if connect_failures > 1 and not rssi_diff:
rssi_diff = 1 rssi_diff = 1
return base_rssi - (rssi_diff * connect_failures * 0.51) return base_rssi - (rssi_diff * connect_failures * 0.51)
@ -300,14 +293,10 @@ class HaBleakClientWrapper(BleakClient):
that has a free connection slot. that has a free connection slot.
""" """
address = self.__address address = self.__address
scanner_device_advertisement_datas = manager.async_get_scanner_discovered_devices_and_advertisement_data_by_address( # noqa: E501 devices = manager.async_scanner_devices_by_address(self.__address, True)
address, True sorted_devices = sorted(
) devices,
sorted_scanner_device_advertisement_datas = sorted( key=lambda device: device.advertisement.rssi or NO_RSSI_VALUE,
scanner_device_advertisement_datas,
key=lambda scanner_device_advertisement_data: (
scanner_device_advertisement_data[2].rssi or NO_RSSI_VALUE
),
reverse=True, reverse=True,
) )
@ -315,31 +304,28 @@ class HaBleakClientWrapper(BleakClient):
# to prefer the adapter/scanner with the less failures so # to prefer the adapter/scanner with the less failures so
# we don't keep trying to connect with an adapter # we don't keep trying to connect with an adapter
# that is failing # that is failing
if ( if self.__connect_failures and len(sorted_devices) > 1:
self.__connect_failures
and len(sorted_scanner_device_advertisement_datas) > 1
):
# We use the rssi diff between to the top two # We use the rssi diff between to the top two
# to adjust the rssi sorter so that each failure # to adjust the rssi sorter so that each failure
# will reduce the rssi sorter by the diff amount # will reduce the rssi sorter by the diff amount
rssi_diff = ( rssi_diff = (
sorted_scanner_device_advertisement_datas[0][2].rssi sorted_devices[0].advertisement.rssi
- sorted_scanner_device_advertisement_datas[1][2].rssi - sorted_devices[1].advertisement.rssi
) )
adjusted_rssi_sorter = partial( adjusted_rssi_sorter = partial(
_rssi_sorter_with_connection_failure_penalty, _rssi_sorter_with_connection_failure_penalty,
connection_failure_count=self.__connect_failures, connection_failure_count=self.__connect_failures,
rssi_diff=rssi_diff, rssi_diff=rssi_diff,
) )
sorted_scanner_device_advertisement_datas = sorted( sorted_devices = sorted(
scanner_device_advertisement_datas, devices,
key=adjusted_rssi_sorter, key=adjusted_rssi_sorter,
reverse=True, reverse=True,
) )
for (scanner, ble_device, _) in sorted_scanner_device_advertisement_datas: for device in sorted_devices:
if backend := self._async_get_backend_for_ble_device( if backend := self._async_get_backend_for_ble_device(
manager, scanner, ble_device manager, device.scanner, device.ble_device
): ):
return backend return backend

View file

@ -1,10 +1,18 @@
"""Tests for the Bluetooth integration API.""" """Tests for the Bluetooth integration API."""
from homeassistant.components import bluetooth from bleak.backends.scanner import AdvertisementData, BLEDevice
from homeassistant.components.bluetooth import async_scanner_by_source
from . import FakeScanner from homeassistant.components import bluetooth
from homeassistant.components.bluetooth import (
BaseHaRemoteScanner,
BaseHaScanner,
HaBluetoothConnector,
async_scanner_by_source,
async_scanner_devices_by_address,
)
from . import FakeScanner, MockBleakClient, _get_manager, generate_advertisement_data
async def test_scanner_by_source(hass, enable_bluetooth): async def test_scanner_by_source(hass, enable_bluetooth):
@ -16,3 +24,116 @@ async def test_scanner_by_source(hass, enable_bluetooth):
assert async_scanner_by_source(hass, "hci2") is hci2_scanner assert async_scanner_by_source(hass, "hci2") is hci2_scanner
cancel_hci2() cancel_hci2()
assert async_scanner_by_source(hass, "hci2") is None assert async_scanner_by_source(hass, "hci2") is None
async def test_async_scanner_devices_by_address_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with connectable devices."""
manager = _get_manager()
class FakeInjectableScanner(BaseHaRemoteScanner):
def inject_advertisement(
self, device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Inject an advertisement."""
self._async_on_advertisement(
device.address,
advertisement_data.rssi,
device.name,
advertisement_data.service_uuids,
advertisement_data.service_data,
advertisement_data.manufacturer_data,
advertisement_data.tx_power,
{"scanner_specific_data": "test"},
)
new_info_callback = manager.scanner_adv_received
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeInjectableScanner(
hass, "esp32", "esp32", new_info_callback, connector, False
)
unsetup = scanner.async_setup()
cancel = manager.async_register_scanner(scanner, True)
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
scanner.inject_advertisement(switchbot_device, switchbot_device_adv)
assert async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
) == async_scanner_devices_by_address(hass, "44:44:33:11:23:45", connectable=False)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
unsetup()
cancel()
async def test_async_scanner_devices_by_address_non_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with non-connectable devices."""
manager = _get_manager()
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
class FakeStaticScanner(BaseHaScanner):
@property
def discovered_devices(self) -> list[BLEDevice]:
"""Return a list of discovered devices."""
return [switchbot_device]
@property
def discovered_devices_and_advertisement_data(
self,
) -> dict[str, tuple[BLEDevice, AdvertisementData]]:
"""Return a list of discovered devices and their advertisement data."""
return {switchbot_device.address: (switchbot_device, switchbot_device_adv)}
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeStaticScanner(hass, "esp32", "esp32", connector)
cancel = manager.async_register_scanner(scanner, False)
assert scanner.discovered_devices_and_advertisement_data == {
switchbot_device.address: (switchbot_device, switchbot_device_adv)
}
assert (
async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
)
== []
)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
cancel()