Expose async_scanner_devices_by_address from the bluetooth api (#83733)

Co-authored-by: J. Nick Koston <nick@koston.org>
fixes undefined
This commit is contained in:
David Buezas 2023-01-09 01:06:32 +01:00 committed by GitHub
parent 06a35fb7db
commit 112b2c22f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 179 additions and 47 deletions

View file

@ -58,9 +58,10 @@ from .api import (
async_register_scanner,
async_scanner_by_source,
async_scanner_count,
async_scanner_devices_by_address,
async_track_unavailable,
)
from .base_scanner import BaseHaRemoteScanner, BaseHaScanner
from .base_scanner import BaseHaRemoteScanner, BaseHaScanner, BluetoothScannerDevice
from .const import (
BLUETOOTH_DISCOVERY_COOLDOWN_SECONDS,
CONF_ADAPTER,
@ -99,6 +100,7 @@ __all__ = [
"async_track_unavailable",
"async_scanner_by_source",
"async_scanner_count",
"async_scanner_devices_by_address",
"BaseHaScanner",
"BaseHaRemoteScanner",
"BluetoothCallbackMatcher",
@ -107,6 +109,7 @@ __all__ = [
"BluetoothServiceInfoBleak",
"BluetoothScanningMode",
"BluetoothCallback",
"BluetoothScannerDevice",
"HaBluetoothConnector",
"SOURCE_LOCAL",
"FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS",

View file

@ -13,7 +13,7 @@ from home_assistant_bluetooth import BluetoothServiceInfoBleak
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback
from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import DATA_MANAGER
from .manager import BluetoothManager
from .match import BluetoothCallbackMatcher
@ -93,6 +93,14 @@ def async_ble_device_from_address(
return _get_manager(hass).async_ble_device_from_address(address, connectable)
@hass_callback
def async_scanner_devices_by_address(
hass: HomeAssistant, address: str, connectable: bool = True
) -> list[BluetoothScannerDevice]:
"""Return all discovered BluetoothScannerDevice for an address."""
return _get_manager(hass).async_scanner_devices_by_address(address, connectable)
@hass_callback
def async_address_present(
hass: HomeAssistant, address: str, connectable: bool = True

View file

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
import datetime
from datetime import timedelta
import logging
@ -39,6 +40,15 @@ MONOTONIC_TIME: Final = monotonic_time_coarse
_LOGGER = logging.getLogger(__name__)
@dataclass
class BluetoothScannerDevice:
"""Data for a bluetooth device from a given scanner."""
scanner: BaseHaScanner
ble_device: BLEDevice
advertisement: AdvertisementData
class BaseHaScanner(ABC):
"""Base class for Ha Scanners."""

View file

@ -29,7 +29,7 @@ from homeassistant.helpers.event import async_track_time_interval
from homeassistant.util.dt import monotonic_time_coarse
from .advertisement_tracker import AdvertisementTracker
from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import (
FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS,
UNAVAILABLE_TRACK_SECONDS,
@ -217,18 +217,22 @@ class BluetoothManager:
uninstall_multiple_bleak_catcher()
@hass_callback
def async_get_scanner_discovered_devices_and_advertisement_data_by_address(
def async_scanner_devices_by_address(
self, address: str, connectable: bool
) -> list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]]:
"""Get scanner, devices, and advertisement_data by address."""
types_ = (True,) if connectable else (True, False)
results: list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]] = []
for type_ in types_:
for scanner in self._get_scanners_by_type(type_):
devices_and_adv_data = scanner.discovered_devices_and_advertisement_data
if device_adv_data := devices_and_adv_data.get(address):
results.append((scanner, *device_adv_data))
return results
) -> list[BluetoothScannerDevice]:
"""Get BluetoothScannerDevice by address."""
scanners = self._get_scanners_by_type(True)
if not connectable:
scanners.extend(self._get_scanners_by_type(False))
return [
BluetoothScannerDevice(scanner, *device_adv)
for scanner in scanners
if (
device_adv := scanner.discovered_devices_and_advertisement_data.get(
address
)
)
]
@hass_callback
def _async_all_discovered_addresses(self, connectable: bool) -> Iterable[str]:

View file

@ -12,11 +12,7 @@ from typing import TYPE_CHECKING, Any, Final
from bleak import BleakClient, BleakError
from bleak.backends.client import BaseBleakClient, get_platform_client_backend_type
from bleak.backends.device import BLEDevice
from bleak.backends.scanner import (
AdvertisementData,
AdvertisementDataCallback,
BaseBleakScanner,
)
from bleak.backends.scanner import AdvertisementDataCallback, BaseBleakScanner
from bleak_retry_connector import (
NO_RSSI_VALUE,
ble_device_description,
@ -28,7 +24,7 @@ from homeassistant.core import CALLBACK_TYPE, callback as hass_callback
from homeassistant.helpers.frame import report
from . import models
from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice
FILTER_UUIDS: Final = "UUIDs"
_LOGGER = logging.getLogger(__name__)
@ -149,9 +145,7 @@ class HaBleakScannerWrapper(BaseBleakScanner):
def _rssi_sorter_with_connection_failure_penalty(
scanner_device_advertisement_data: tuple[
BaseHaScanner, BLEDevice, AdvertisementData
],
device: BluetoothScannerDevice,
connection_failure_count: dict[BaseHaScanner, int],
rssi_diff: int,
) -> float:
@ -168,9 +162,8 @@ def _rssi_sorter_with_connection_failure_penalty(
best adapter twice before moving on to the next best adapter since
the first failure may be a transient service resolution issue.
"""
scanner, _, advertisement_data = scanner_device_advertisement_data
base_rssi = advertisement_data.rssi or NO_RSSI_VALUE
if connect_failures := connection_failure_count.get(scanner):
base_rssi = device.advertisement.rssi or NO_RSSI_VALUE
if connect_failures := connection_failure_count.get(device.scanner):
if connect_failures > 1 and not rssi_diff:
rssi_diff = 1
return base_rssi - (rssi_diff * connect_failures * 0.51)
@ -300,14 +293,10 @@ class HaBleakClientWrapper(BleakClient):
that has a free connection slot.
"""
address = self.__address
scanner_device_advertisement_datas = manager.async_get_scanner_discovered_devices_and_advertisement_data_by_address( # noqa: E501
address, True
)
sorted_scanner_device_advertisement_datas = sorted(
scanner_device_advertisement_datas,
key=lambda scanner_device_advertisement_data: (
scanner_device_advertisement_data[2].rssi or NO_RSSI_VALUE
),
devices = manager.async_scanner_devices_by_address(self.__address, True)
sorted_devices = sorted(
devices,
key=lambda device: device.advertisement.rssi or NO_RSSI_VALUE,
reverse=True,
)
@ -315,31 +304,28 @@ class HaBleakClientWrapper(BleakClient):
# to prefer the adapter/scanner with the less failures so
# we don't keep trying to connect with an adapter
# that is failing
if (
self.__connect_failures
and len(sorted_scanner_device_advertisement_datas) > 1
):
if self.__connect_failures and len(sorted_devices) > 1:
# We use the rssi diff between to the top two
# to adjust the rssi sorter so that each failure
# will reduce the rssi sorter by the diff amount
rssi_diff = (
sorted_scanner_device_advertisement_datas[0][2].rssi
- sorted_scanner_device_advertisement_datas[1][2].rssi
sorted_devices[0].advertisement.rssi
- sorted_devices[1].advertisement.rssi
)
adjusted_rssi_sorter = partial(
_rssi_sorter_with_connection_failure_penalty,
connection_failure_count=self.__connect_failures,
rssi_diff=rssi_diff,
)
sorted_scanner_device_advertisement_datas = sorted(
scanner_device_advertisement_datas,
sorted_devices = sorted(
devices,
key=adjusted_rssi_sorter,
reverse=True,
)
for (scanner, ble_device, _) in sorted_scanner_device_advertisement_datas:
for device in sorted_devices:
if backend := self._async_get_backend_for_ble_device(
manager, scanner, ble_device
manager, device.scanner, device.ble_device
):
return backend

View file

@ -1,10 +1,18 @@
"""Tests for the Bluetooth integration API."""
from homeassistant.components import bluetooth
from homeassistant.components.bluetooth import async_scanner_by_source
from bleak.backends.scanner import AdvertisementData, BLEDevice
from . import FakeScanner
from homeassistant.components import bluetooth
from homeassistant.components.bluetooth import (
BaseHaRemoteScanner,
BaseHaScanner,
HaBluetoothConnector,
async_scanner_by_source,
async_scanner_devices_by_address,
)
from . import FakeScanner, MockBleakClient, _get_manager, generate_advertisement_data
async def test_scanner_by_source(hass, enable_bluetooth):
@ -16,3 +24,116 @@ async def test_scanner_by_source(hass, enable_bluetooth):
assert async_scanner_by_source(hass, "hci2") is hci2_scanner
cancel_hci2()
assert async_scanner_by_source(hass, "hci2") is None
async def test_async_scanner_devices_by_address_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with connectable devices."""
manager = _get_manager()
class FakeInjectableScanner(BaseHaRemoteScanner):
def inject_advertisement(
self, device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Inject an advertisement."""
self._async_on_advertisement(
device.address,
advertisement_data.rssi,
device.name,
advertisement_data.service_uuids,
advertisement_data.service_data,
advertisement_data.manufacturer_data,
advertisement_data.tx_power,
{"scanner_specific_data": "test"},
)
new_info_callback = manager.scanner_adv_received
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeInjectableScanner(
hass, "esp32", "esp32", new_info_callback, connector, False
)
unsetup = scanner.async_setup()
cancel = manager.async_register_scanner(scanner, True)
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
scanner.inject_advertisement(switchbot_device, switchbot_device_adv)
assert async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
) == async_scanner_devices_by_address(hass, "44:44:33:11:23:45", connectable=False)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
unsetup()
cancel()
async def test_async_scanner_devices_by_address_non_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with non-connectable devices."""
manager = _get_manager()
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
class FakeStaticScanner(BaseHaScanner):
@property
def discovered_devices(self) -> list[BLEDevice]:
"""Return a list of discovered devices."""
return [switchbot_device]
@property
def discovered_devices_and_advertisement_data(
self,
) -> dict[str, tuple[BLEDevice, AdvertisementData]]:
"""Return a list of discovered devices and their advertisement data."""
return {switchbot_device.address: (switchbot_device, switchbot_device_adv)}
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeStaticScanner(hass, "esp32", "esp32", connector)
cancel = manager.async_register_scanner(scanner, False)
assert scanner.discovered_devices_and_advertisement_data == {
switchbot_device.address: (switchbot_device, switchbot_device_adv)
}
assert (
async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
)
== []
)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
cancel()