Use stable USB device path in USB discovery (#94266)

This commit is contained in:
Erik Montnemery 2023-06-08 18:27:04 +02:00 committed by GitHub
parent 6db1fbf480
commit c8756ba5bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 19 deletions

View file

@ -168,12 +168,9 @@ class InsteonFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
dev_path = await self.hass.async_add_executor_job(
usb.get_serial_by_id, discovery_info.device
)
self._device_path = dev_path
self._device_path = discovery_info.device
self._device_name = usb.human_readable_device_name(
dev_path,
discovery_info.device,
discovery_info.serial_number,
discovery_info.manufacturer,
discovery_info.description,

View file

@ -300,8 +300,7 @@ class USBDiscovery:
return _async_remove_callback
@hass_callback
def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
async def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
"""Process a USB discovery."""
_LOGGER.debug("Discovered USB Device: %s", device)
device_tuple = dataclasses.astuple(device)
@ -313,14 +312,7 @@ class USBDiscovery:
if not matched:
return
service_info = UsbServiceInfo(
device=device.device,
vid=device.vid,
pid=device.pid,
serial_number=device.serial_number,
manufacturer=device.manufacturer,
description=device.description,
)
service_info: UsbServiceInfo | None = None
sorted_by_most_targeted = sorted(matched, key=lambda item: -len(item))
most_matched_fields = len(sorted_by_most_targeted[0])
@ -331,6 +323,18 @@ class USBDiscovery:
if len(matcher) < most_matched_fields:
break
if service_info is None:
service_info = UsbServiceInfo(
device=await self.hass.async_add_executor_job(
get_serial_by_id, device.device
),
vid=device.vid,
pid=device.pid,
serial_number=device.serial_number,
manufacturer=device.manufacturer,
description=device.description,
)
discovery_flow.async_create_flow(
self.hass,
matcher["domain"],
@ -338,17 +342,18 @@ class USBDiscovery:
service_info,
)
@hass_callback
def _async_process_ports(self, ports: list[ListPortInfo]) -> None:
async def _async_process_ports(self, ports: list[ListPortInfo]) -> None:
"""Process each discovered port."""
for port in ports:
if port.vid is None and port.pid is None:
continue
self._async_process_discovered_usb_device(usb_device_from_port(port))
await self._async_process_discovered_usb_device(usb_device_from_port(port))
async def _async_scan_serial(self) -> None:
"""Scan serial ports."""
self._async_process_ports(await self.hass.async_add_executor_job(comports))
await self._async_process_ports(
await self.hass.async_add_executor_job(comports)
)
if self.initial_scan_done:
return

View file

@ -1021,3 +1021,45 @@ async def test_cancel_initial_scan_callback(
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert len(mock_callback.mock_calls) == 0
async def test_resolve_serial_by_id(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the discovery data resolves to serial/by-id."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
mock_comports = [
MagicMock(
device=slae_sh_device.device,
vid=12345,
pid=12345,
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
]
with patch("pyudev.Context", side_effect=ImportError), patch(
"homeassistant.components.usb.async_get_usb", return_value=new_usb
), patch(
"homeassistant.components.usb.comports", return_value=mock_comports
), patch(
"homeassistant.components.usb.get_serial_by_id",
return_value="/dev/serial/by-id/bla",
), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow:
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1
assert mock_config_flow.mock_calls[0][1][0] == "test1"
assert mock_config_flow.mock_calls[0][2]["data"].device == "/dev/serial/by-id/bla"