Switch periodic USB scanning to on-demand websocket when observer is not available (#54953)

This commit is contained in:
J. Nick Koston 2021-08-21 16:06:44 -05:00 committed by GitHub
parent a931e35a14
commit 42f7f19be5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 39 deletions

View file

@ -2,29 +2,31 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import datetime
import logging import logging
import os import os
import sys import sys
from serial.tools.list_ports import comports from serial.tools.list_ports import comports
from serial.tools.list_ports_common import ListPortInfo from serial.tools.list_ports_common import ListPortInfo
import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_usb from homeassistant.loader import async_get_usb
from .const import DOMAIN
from .flow import FlowDispatcher, USBFlow from .flow import FlowDispatcher, USBFlow
from .models import USBDevice from .models import USBDevice
from .utils import usb_device_from_port from .utils import usb_device_from_port
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Perodic scanning only happens on non-linux systems REQUEST_SCAN_COOLDOWN = 60 # 1 minute cooldown
SCAN_INTERVAL = datetime.timedelta(minutes=60)
def human_readable_device_name( def human_readable_device_name(
@ -63,6 +65,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
usb = await async_get_usb(hass) usb = await async_get_usb(hass)
usb_discovery = USBDiscovery(hass, FlowDispatcher(hass), usb) usb_discovery = USBDiscovery(hass, FlowDispatcher(hass), usb)
await usb_discovery.async_setup() await usb_discovery.async_setup()
hass.data[DOMAIN] = usb_discovery
websocket_api.async_register_command(hass, websocket_usb_scan)
return True return True
@ -80,31 +85,23 @@ class USBDiscovery:
self.flow_dispatcher = flow_dispatcher self.flow_dispatcher = flow_dispatcher
self.usb = usb self.usb = usb
self.seen: set[tuple[str, ...]] = set() self.seen: set[tuple[str, ...]] = set()
self.observer_active = False
self._request_debouncer: Debouncer | None = None
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Set up USB Discovery.""" """Set up USB Discovery."""
if not await self._async_start_monitor(): await self._async_start_monitor()
await self._async_start_scanner()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start)
async def async_start(self, event: Event) -> None: async def async_start(self, event: Event) -> None:
"""Start USB Discovery and run a manual scan.""" """Start USB Discovery and run a manual scan."""
self.flow_dispatcher.async_start() self.flow_dispatcher.async_start()
await self.hass.async_add_executor_job(self.scan_serial) await self._async_scan_serial()
async def _async_start_scanner(self) -> None: async def _async_start_monitor(self) -> None:
"""Perodic scan with pyserial when the observer is not available."""
stop_track = async_track_time_interval(
self.hass, lambda now: self.scan_serial(), SCAN_INTERVAL
)
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, callback(lambda event: stop_track())
)
async def _async_start_monitor(self) -> bool:
"""Start monitoring hardware with pyudev.""" """Start monitoring hardware with pyudev."""
if not sys.platform.startswith("linux"): if not sys.platform.startswith("linux"):
return False return
from pyudev import ( # pylint: disable=import-outside-toplevel from pyudev import ( # pylint: disable=import-outside-toplevel
Context, Context,
Monitor, Monitor,
@ -114,7 +111,7 @@ class USBDiscovery:
try: try:
context = Context() context = Context()
except (ImportError, OSError): except (ImportError, OSError):
return False return
monitor = Monitor.from_netlink(context) monitor = Monitor.from_netlink(context)
monitor.filter_by(subsystem="tty") monitor.filter_by(subsystem="tty")
@ -125,7 +122,7 @@ class USBDiscovery:
self.hass.bus.async_listen_once( self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, lambda event: observer.stop() EVENT_HOMEASSISTANT_STOP, lambda event: observer.stop()
) )
return True self.observer_active = True
def _device_discovered(self, device): def _device_discovered(self, device):
"""Call when the observer discovers a new usb tty device.""" """Call when the observer discovers a new usb tty device."""
@ -168,3 +165,34 @@ class USBDiscovery:
def scan_serial(self) -> None: def scan_serial(self) -> None:
"""Scan serial ports.""" """Scan serial ports."""
self.hass.add_job(self._async_process_ports, comports()) self.hass.add_job(self._async_process_ports, comports())
async def _async_scan_serial(self) -> None:
"""Scan serial ports."""
self._async_process_ports(await self.hass.async_add_executor_job(comports))
async def async_request_scan_serial(self) -> None:
"""Request a serial scan."""
if not self._request_debouncer:
self._request_debouncer = Debouncer(
self.hass,
_LOGGER,
cooldown=REQUEST_SCAN_COOLDOWN,
immediate=True,
function=self._async_scan_serial,
)
await self._request_debouncer.async_call()
@websocket_api.require_admin
@websocket_api.websocket_command({vol.Required("type"): "usb/scan"})
@websocket_api.async_response
async def websocket_usb_scan(
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict,
) -> None:
"""Scan for new usb devices."""
usb_discovery: USBDiscovery = hass.data[DOMAIN]
if not usb_discovery.observer_active:
await usb_discovery.async_request_scan_serial()
connection.send_result(msg["id"])

View file

@ -7,6 +7,7 @@
"pyserial==3.5" "pyserial==3.5"
], ],
"codeowners": ["@bdraco"], "codeowners": ["@bdraco"],
"dependencies": ["websocket_api"],
"quality_scale": "internal", "quality_scale": "internal",
"iot_class": "local_push" "iot_class": "local_push"
} }

View file

@ -1,5 +1,4 @@
"""Tests for the USB Discovery integration.""" """Tests for the USB Discovery integration."""
import datetime
import os import os
import sys import sys
from unittest.mock import MagicMock, patch, sentinel from unittest.mock import MagicMock, patch, sentinel
@ -9,12 +8,9 @@ import pytest
from homeassistant.components import usb from homeassistant.components import usb
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from . import slae_sh_device from . import slae_sh_device
from tests.common import async_fire_time_changed
@pytest.mark.skipif( @pytest.mark.skipif(
not sys.platform.startswith("linux"), not sys.platform.startswith("linux"),
@ -113,8 +109,8 @@ async def test_removal_by_observer_before_started(hass):
assert len(mock_config_flow.mock_calls) == 0 assert len(mock_config_flow.mock_calls) == 0
async def test_discovered_by_scanner_after_started(hass): async def test_discovered_by_websocket_scan(hass, hass_ws_client):
"""Test a device is discovered by the scanner after the started event.""" """Test a device is discovered from websocket scan."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}] new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
mock_comports = [ mock_comports = [
@ -139,15 +135,18 @@ async def test_discovered_by_scanner_after_started(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=1)) ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1 assert len(mock_config_flow.mock_calls) == 1
assert mock_config_flow.mock_calls[0][1][0] == "test1" assert mock_config_flow.mock_calls[0][1][0] == "test1"
async def test_discovered_by_scanner_after_started_match_vid_only(hass): async def test_discovered_by_websocket_scan_match_vid_only(hass, hass_ws_client):
"""Test a device is discovered by the scanner after the started event only matching vid.""" """Test a device is discovered from websocket scan only matching vid."""
new_usb = [{"domain": "test1", "vid": "3039"}] new_usb = [{"domain": "test1", "vid": "3039"}]
mock_comports = [ mock_comports = [
@ -172,15 +171,18 @@ async def test_discovered_by_scanner_after_started_match_vid_only(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=1)) ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1 assert len(mock_config_flow.mock_calls) == 1
assert mock_config_flow.mock_calls[0][1][0] == "test1" assert mock_config_flow.mock_calls[0][1][0] == "test1"
async def test_discovered_by_scanner_after_started_match_vid_wrong_pid(hass): async def test_discovered_by_websocket_scan_match_vid_wrong_pid(hass, hass_ws_client):
"""Test a device is discovered by the scanner after the started event only matching vid but wrong pid.""" """Test a device is discovered from websocket scan only matching vid but wrong pid."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "9999"}] new_usb = [{"domain": "test1", "vid": "3039", "pid": "9999"}]
mock_comports = [ mock_comports = [
@ -205,14 +207,17 @@ async def test_discovered_by_scanner_after_started_match_vid_wrong_pid(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=1)) ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0 assert len(mock_config_flow.mock_calls) == 0
async def test_discovered_by_scanner_after_started_no_vid_pid(hass): async def test_discovered_by_websocket_no_vid_pid(hass, hass_ws_client):
"""Test a device is discovered by the scanner after the started event with no vid or pid.""" """Test a device is discovered from websocket scan with no vid or pid."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "9999"}] new_usb = [{"domain": "test1", "vid": "3039", "pid": "9999"}]
mock_comports = [ mock_comports = [
@ -237,15 +242,20 @@ async def test_discovered_by_scanner_after_started_no_vid_pid(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=1)) ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0 assert len(mock_config_flow.mock_calls) == 0
@pytest.mark.parametrize("exception_type", [ImportError, OSError]) @pytest.mark.parametrize("exception_type", [ImportError, OSError])
async def test_non_matching_discovered_by_scanner_after_started(hass, exception_type): async def test_non_matching_discovered_by_scanner_after_started(
"""Test a device is discovered by the scanner after the started event that does not match.""" hass, exception_type, hass_ws_client
):
"""Test a websocket scan that does not match."""
new_usb = [{"domain": "test1", "vid": "4444", "pid": "4444"}] new_usb = [{"domain": "test1", "vid": "4444", "pid": "4444"}]
mock_comports = [ mock_comports = [
@ -270,7 +280,10 @@ async def test_non_matching_discovered_by_scanner_after_started(hass, exception_
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
async_fire_time_changed(hass, dt_util.utcnow() + datetime.timedelta(hours=1)) ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0 assert len(mock_config_flow.mock_calls) == 0