Discover new bluetooth adapters when they are plugged in (#77006)

This commit is contained in:
J. Nick Koston 2022-08-22 15:45:08 -10:00 committed by GitHub
parent 325557c3e9
commit c76dec138a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 230 additions and 47 deletions

View file

@ -3,16 +3,19 @@ from __future__ import annotations
from asyncio import Future from asyncio import Future
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
import logging
import platform import platform
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, cast
import async_timeout import async_timeout
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import usb
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr, discovery_flow from homeassistant.helpers import device_registry as dr, discovery_flow
from homeassistant.helpers.debounce import Debouncer
from homeassistant.loader import async_get_bluetooth from homeassistant.loader import async_get_bluetooth
from . import models from . import models
@ -65,6 +68,8 @@ __all__ = [
"SOURCE_LOCAL", "SOURCE_LOCAL",
] ]
_LOGGER = logging.getLogger(__name__)
def _get_manager(hass: HomeAssistant) -> BluetoothManager: def _get_manager(hass: HomeAssistant) -> BluetoothManager:
"""Get the bluetooth manager.""" """Get the bluetooth manager."""
@ -214,6 +219,31 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_migrate_entries(hass, adapters) async_migrate_entries(hass, adapters)
await async_discover_adapters(hass, adapters) await async_discover_adapters(hass, adapters)
async def _async_rediscover_adapters() -> None:
"""Rediscover adapters when a new one may be available."""
discovered_adapters = await manager.async_get_bluetooth_adapters(cached=False)
_LOGGER.debug("Rediscovered adapters: %s", discovered_adapters)
await async_discover_adapters(hass, discovered_adapters)
discovery_debouncer = Debouncer(
hass, _LOGGER, cooldown=5, immediate=False, function=_async_rediscover_adapters
)
def _async_trigger_discovery() -> None:
# There are so many bluetooth adapter models that
# we check the bus whenever a usb device is plugged in
# to see if it is a bluetooth adapter since we can't
# tell if the device is a bluetooth adapter or if its
# actually supported unless we ask DBus if its now
# present.
_LOGGER.debug("Triggering bluetooth usb discovery")
hass.async_create_task(discovery_debouncer.async_call())
cancel = usb.async_register_scan_request_callback(hass, _async_trigger_discovery)
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, hass_callback(lambda event: cancel())
)
return True return True

View file

@ -2,7 +2,7 @@
"domain": "bluetooth", "domain": "bluetooth",
"name": "Bluetooth", "name": "Bluetooth",
"documentation": "https://www.home-assistant.io/integrations/bluetooth", "documentation": "https://www.home-assistant.io/integrations/bluetooth",
"dependencies": ["websocket_api"], "dependencies": ["usb"],
"quality_scale": "internal", "quality_scale": "internal",
"requirements": [ "requirements": [
"bleak==0.15.1", "bleak==0.15.1",

View file

@ -1,6 +1,8 @@
"""The Home Assistant Sky Connect integration.""" """The Home Assistant Sky Connect integration."""
from __future__ import annotations from __future__ import annotations
from typing import cast
from homeassistant.components import usb from homeassistant.components import usb
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -17,7 +19,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
manufacturer=entry.data["manufacturer"], manufacturer=entry.data["manufacturer"],
description=entry.data["description"], description=entry.data["description"],
) )
if not usb.async_is_plugged_in(hass, entry.data): if not usb.async_is_plugged_in(hass, cast(usb.USBCallbackMatcher, entry.data)):
# The USB dongle is not plugged in # The USB dongle is not plugged in
raise ConfigEntryNotReady raise ConfigEntryNotReady

View file

@ -1,7 +1,7 @@
"""The USB Discovery integration.""" """The USB Discovery integration."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Coroutine, Mapping from collections.abc import Coroutine
import dataclasses import dataclasses
import fnmatch import fnmatch
import logging import logging
@ -17,12 +17,17 @@ from homeassistant import config_entries
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
callback as hass_callback,
)
from homeassistant.data_entry_flow import BaseServiceInfo from homeassistant.data_entry_flow import BaseServiceInfo
from homeassistant.helpers import discovery_flow, system_info from homeassistant.helpers import discovery_flow, system_info
from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_usb from homeassistant.loader import USBMatcher, async_get_usb
from .const import DOMAIN from .const import DOMAIN
from .models import USBDevice from .models import USBDevice
@ -35,6 +40,36 @@ _LOGGER = logging.getLogger(__name__)
REQUEST_SCAN_COOLDOWN = 60 # 1 minute cooldown REQUEST_SCAN_COOLDOWN = 60 # 1 minute cooldown
__all__ = [
"async_is_plugged_in",
"async_register_scan_request_callback",
"USBCallbackMatcher",
"UsbServiceInfo",
]
class USBCallbackMatcher(USBMatcher):
"""Callback matcher for the USB integration."""
@hass_callback
def async_register_scan_request_callback(
hass: HomeAssistant, callback: CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Register to receive a callback when a scan should be initiated."""
discovery: USBDiscovery = hass.data[DOMAIN]
return discovery.async_register_scan_request_callback(callback)
@hass_callback
def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool:
"""Return True is a USB device is present."""
usb_discovery: USBDiscovery = hass.data[DOMAIN]
return any(
_is_matching(USBDevice(*device_tuple), matcher)
for device_tuple in usb_discovery.seen
)
@dataclasses.dataclass @dataclasses.dataclass
class UsbServiceInfo(BaseServiceInfo): class UsbServiceInfo(BaseServiceInfo):
@ -97,7 +132,7 @@ def _fnmatch_lower(name: str | None, pattern: str) -> bool:
return fnmatch.fnmatch(name.lower(), pattern) return fnmatch.fnmatch(name.lower(), pattern)
def _is_matching(device: USBDevice, matcher: Mapping[str, str]) -> bool: def _is_matching(device: USBDevice, matcher: USBMatcher | USBCallbackMatcher) -> bool:
"""Return True if a device matches.""" """Return True if a device matches."""
if "vid" in matcher and device.vid != matcher["vid"]: if "vid" in matcher and device.vid != matcher["vid"]:
return False return False
@ -124,7 +159,7 @@ class USBDiscovery:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
usb: list[dict[str, str]], usb: list[USBMatcher],
) -> None: ) -> None:
"""Init USB Discovery.""" """Init USB Discovery."""
self.hass = hass self.hass = hass
@ -132,6 +167,7 @@ class USBDiscovery:
self.seen: set[tuple[str, ...]] = set() self.seen: set[tuple[str, ...]] = set()
self.observer_active = False self.observer_active = False
self._request_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None self._request_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
self._request_callbacks: list[CALLBACK_TYPE] = []
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Set up USB Discovery.""" """Set up USB Discovery."""
@ -188,9 +224,23 @@ class USBDiscovery:
"Discovered Device at path: %s, triggering scan serial", "Discovered Device at path: %s, triggering scan serial",
device.device_path, device.device_path,
) )
self.scan_serial() self.hass.create_task(self._async_scan())
@callback @hass_callback
def async_register_scan_request_callback(
self,
_callback: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Register a callback."""
self._request_callbacks.append(_callback)
@hass_callback
def _async_remove_callback() -> None:
self._request_callbacks.remove(_callback)
return _async_remove_callback
@hass_callback
def _async_process_discovered_usb_device(self, device: USBDevice) -> None: def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
"""Process a USB discovery.""" """Process a USB discovery."""
_LOGGER.debug("Discovered USB Device: %s", device) _LOGGER.debug("Discovered USB Device: %s", device)
@ -198,14 +248,20 @@ class USBDiscovery:
if device_tuple in self.seen: if device_tuple in self.seen:
return return
self.seen.add(device_tuple) self.seen.add(device_tuple)
matched = []
for matcher in self.usb:
if _is_matching(device, matcher):
matched.append(matcher)
matched = [matcher for matcher in self.usb if _is_matching(device, matcher)]
if not matched: if not matched:
return return
service_info = UsbServiceInfo(
device=device.device,
vid=device.vid,
pid=device.pid,
serial_number=device.serial_number,
manufacturer=device.manufacturer,
description=device.description,
)
sorted_by_most_targeted = sorted(matched, key=lambda item: -len(item)) sorted_by_most_targeted = sorted(matched, key=lambda item: -len(item))
most_matched_fields = len(sorted_by_most_targeted[0]) most_matched_fields = len(sorted_by_most_targeted[0])
@ -219,17 +275,10 @@ class USBDiscovery:
self.hass, self.hass,
matcher["domain"], matcher["domain"],
{"source": config_entries.SOURCE_USB}, {"source": config_entries.SOURCE_USB},
UsbServiceInfo( service_info,
device=device.device,
vid=device.vid,
pid=device.pid,
serial_number=device.serial_number,
manufacturer=device.manufacturer,
description=device.description,
),
) )
@callback @hass_callback
def _async_process_ports(self, ports: list[ListPortInfo]) -> None: def _async_process_ports(self, ports: list[ListPortInfo]) -> None:
"""Process each discovered port.""" """Process each discovered port."""
for port in ports: for port in ports:
@ -237,15 +286,17 @@ class USBDiscovery:
continue continue
self._async_process_discovered_usb_device(usb_device_from_port(port)) self._async_process_discovered_usb_device(usb_device_from_port(port))
def scan_serial(self) -> None:
"""Scan serial ports."""
self.hass.add_job(self._async_process_ports, comports())
async def _async_scan_serial(self) -> None: async def _async_scan_serial(self) -> None:
"""Scan serial ports.""" """Scan serial ports."""
self._async_process_ports(await self.hass.async_add_executor_job(comports)) self._async_process_ports(await self.hass.async_add_executor_job(comports))
async def async_request_scan_serial(self) -> None: async def _async_scan(self) -> None:
"""Scan for USB devices and notify callbacks to scan as well."""
for callback in self._request_callbacks:
callback()
await self._async_scan_serial()
async def async_request_scan(self) -> None:
"""Request a serial scan.""" """Request a serial scan."""
if not self._request_debouncer: if not self._request_debouncer:
self._request_debouncer = Debouncer( self._request_debouncer = Debouncer(
@ -253,7 +304,7 @@ class USBDiscovery:
_LOGGER, _LOGGER,
cooldown=REQUEST_SCAN_COOLDOWN, cooldown=REQUEST_SCAN_COOLDOWN,
immediate=True, immediate=True,
function=self._async_scan_serial, function=self._async_scan,
) )
await self._request_debouncer.async_call() await self._request_debouncer.async_call()
@ -269,16 +320,5 @@ async def websocket_usb_scan(
"""Scan for new usb devices.""" """Scan for new usb devices."""
usb_discovery: USBDiscovery = hass.data[DOMAIN] usb_discovery: USBDiscovery = hass.data[DOMAIN]
if not usb_discovery.observer_active: if not usb_discovery.observer_active:
await usb_discovery.async_request_scan_serial() await usb_discovery.async_request_scan()
connection.send_result(msg["id"]) connection.send_result(msg["id"])
@callback
def async_is_plugged_in(hass: HomeAssistant, matcher: Mapping) -> bool:
"""Return True is a USB device is present."""
usb_discovery: USBDiscovery = hass.data[DOMAIN]
for device_tuple in usb_discovery.seen:
device = USBDevice(*device_tuple)
if _is_matching(device, matcher):
return True
return False

View file

@ -98,6 +98,26 @@ class BluetoothMatcher(BluetoothMatcherRequired, BluetoothMatcherOptional):
"""Matcher for the bluetooth integration.""" """Matcher for the bluetooth integration."""
class USBMatcherRequired(TypedDict, total=True):
"""Matcher for the usb integration for required fields."""
domain: str
class USBMatcherOptional(TypedDict, total=False):
"""Matcher for the usb integration for optional fields."""
vid: str
pid: str
serial_number: str
manufacturer: str
description: str
class USBMatcher(USBMatcherRequired, USBMatcherOptional):
"""Matcher for the bluetooth integration."""
class Manifest(TypedDict, total=False): class Manifest(TypedDict, total=False):
""" """
Integration manifest. Integration manifest.
@ -318,9 +338,9 @@ async def async_get_dhcp(hass: HomeAssistant) -> list[DHCPMatcher]:
return dhcp return dhcp
async def async_get_usb(hass: HomeAssistant) -> list[dict[str, str]]: async def async_get_usb(hass: HomeAssistant) -> list[USBMatcher]:
"""Return cached list of usb types.""" """Return cached list of usb types."""
usb: list[dict[str, str]] = USB.copy() usb = cast(list[USBMatcher], USB.copy())
integrations = await async_get_custom_components(hass) integrations = await async_get_custom_components(hass)
for integration in integrations.values(): for integration in integrations.values():
@ -328,10 +348,13 @@ async def async_get_usb(hass: HomeAssistant) -> list[dict[str, str]]:
continue continue
for entry in integration.usb: for entry in integration.usb:
usb.append( usb.append(
{ cast(
"domain": integration.domain, USBMatcher,
**{k: v for k, v in entry.items() if k != "known_devices"}, {
} "domain": integration.domain,
**{k: v for k, v in entry.items() if k != "known_devices"},
},
)
) )
return usb return usb

View file

@ -1641,3 +1641,59 @@ async def test_migrate_single_entry_linux(hass, mock_bleak_scanner_start, one_ad
assert await async_setup_component(hass, bluetooth.DOMAIN, {}) assert await async_setup_component(hass, bluetooth.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.unique_id == "00:00:00:00:00:01" assert entry.unique_id == "00:00:00:00:00:01"
async def test_discover_new_usb_adapters(hass, mock_bleak_scanner_start, one_adapter):
"""Test we can discover new usb adapters."""
entry = MockConfigEntry(
domain=bluetooth.DOMAIN, data={}, unique_id="00:00:00:00:00:01"
)
entry.add_to_hass(hass)
saved_callback = None
def _async_register_scan_request_callback(_hass, _callback):
nonlocal saved_callback
saved_callback = _callback
return lambda: None
with patch(
"homeassistant.components.bluetooth.usb.async_register_scan_request_callback",
_async_register_scan_request_callback,
):
assert await async_setup_component(hass, bluetooth.DOMAIN, {})
await hass.async_block_till_done()
assert not hass.config_entries.flow.async_progress(DOMAIN)
saved_callback()
assert not hass.config_entries.flow.async_progress(DOMAIN)
with patch(
"homeassistant.components.bluetooth.util.platform.system", return_value="Linux"
), patch(
"bluetooth_adapters.get_bluetooth_adapter_details",
return_value={
"hci0": {
"org.bluez.Adapter1": {
"Address": "00:00:00:00:00:01",
"Name": "BlueZ 4.63",
"Modalias": "usbid:1234",
}
},
"hci1": {
"org.bluez.Adapter1": {
"Address": "00:00:00:00:00:02",
"Name": "BlueZ 4.63",
"Modalias": "usbid:1234",
}
},
},
):
for wait_sec in range(10, 20):
async_fire_time_changed(
hass, dt_util.utcnow() + timedelta(seconds=wait_sec)
)
await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress(DOMAIN)) == 1

View file

@ -1,7 +1,7 @@
"""Tests for the USB Discovery integration.""" """Tests for the USB Discovery integration."""
import os import os
import sys import sys
from unittest.mock import MagicMock, call, patch, sentinel from unittest.mock import MagicMock, Mock, call, patch, sentinel
import pytest import pytest
@ -875,3 +875,35 @@ async def test_async_is_plugged_in(hass, hass_ws_client):
assert response["success"] assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert usb.async_is_plugged_in(hass, matcher) assert usb.async_is_plugged_in(hass, matcher)
async def test_web_socket_triggers_discovery_request_callbacks(hass, hass_ws_client):
"""Test the websocket call triggers a discovery request callback."""
mock_callback = Mock()
with patch("pyudev.Context", side_effect=ImportError), patch(
"homeassistant.components.usb.async_get_usb", return_value=[]
), patch("homeassistant.components.usb.comports", return_value=[]), patch.object(
hass.config_entries.flow, "async_init"
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
cancel = usb.async_register_scan_request_callback(hass, mock_callback)
ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert len(mock_callback.mock_calls) == 1
cancel()
await ws_client.send_json({"id": 2, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert len(mock_callback.mock_calls) == 1