From c00a5fad8fecb305834bc8e962117f8f245adb1e Mon Sep 17 00:00:00 2001
From: epenet <6771947+epenet@users.noreply.github.com>
Date: Fri, 22 Oct 2021 11:45:40 +0200
Subject: [PATCH] Cleanup device registration in Onewire (#58101)

* Add checks for device registry

* Move registry checks to init.py

* Run device registry check on disabled devices

* Empty commit for testing

* Register devices during initialisation

* Adjust tests accordingly

* Add via_device to device info

* Adjust access to device registry

Co-authored-by: epenet <epenet@users.noreply.github.com>
---
 homeassistant/components/onewire/__init__.py  |  36 ++---
 .../components/onewire/binary_sensor.py       |  30 ++---
 homeassistant/components/onewire/model.py     |  26 +++-
 .../components/onewire/onewirehub.py          | 112 +++++++++++++---
 homeassistant/components/onewire/sensor.py    |  41 +++---
 homeassistant/components/onewire/switch.py    |  31 ++---
 tests/components/onewire/__init__.py          | 124 +++++++++++++++---
 tests/components/onewire/const.py             |  27 ++--
 .../components/onewire/test_binary_sensor.py  |  10 +-
 tests/components/onewire/test_sensor.py       | 123 +++--------------
 tests/components/onewire/test_switch.py       |  10 +-
 11 files changed, 317 insertions(+), 253 deletions(-)

diff --git a/homeassistant/components/onewire/__init__.py b/homeassistant/components/onewire/__init__.py
index b99f095de7b..5981a654820 100644
--- a/homeassistant/components/onewire/__init__.py
+++ b/homeassistant/components/onewire/__init__.py
@@ -5,7 +5,7 @@ import logging
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import ConfigEntryNotReady
-from homeassistant.helpers import device_registry as dr, entity_registry as er
+from homeassistant.helpers import device_registry as dr
 
 from .const import DOMAIN, PLATFORMS
 from .onewirehub import CannotConnect, OneWireHub
@@ -25,31 +25,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
 
     hass.data[DOMAIN][entry.entry_id] = onewirehub
 
-    async def cleanup_registry() -> None:
+    async def cleanup_registry(onewirehub: OneWireHub) -> None:
         # Get registries
-        device_registry, entity_registry = await asyncio.gather(
-            hass.helpers.device_registry.async_get_registry(),
-            hass.helpers.entity_registry.async_get_registry(),
-        )
+        device_registry = dr.async_get(hass)
         # Generate list of all device entries
-        registry_devices = [
-            entry.id
-            for entry in dr.async_entries_for_config_entry(
-                device_registry, entry.entry_id
-            )
-        ]
+        registry_devices = list(
+            dr.async_entries_for_config_entry(device_registry, entry.entry_id)
+        )
         # Remove devices that don't belong to any entity
-        for device_id in registry_devices:
-            if not er.async_entries_for_device(
-                entity_registry, device_id, include_disabled_entities=True
-            ):
+        for device in registry_devices:
+            if not onewirehub.has_device_in_cache(device):
                 _LOGGER.debug(
-                    "Removing device `%s` because it does not have any entities",
-                    device_id,
+                    "Removing device `%s` because it is no longer available",
+                    device.id,
                 )
-                device_registry.async_remove_device(device_id)
+                device_registry.async_remove_device(device.id)
 
-    async def start_platforms() -> None:
+    async def start_platforms(onewirehub: OneWireHub) -> None:
         """Start platforms and cleanup devices."""
         # wait until all required platforms are ready
         await asyncio.gather(
@@ -58,9 +50,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
                 for platform in PLATFORMS
             )
         )
-        await cleanup_registry()
+        await cleanup_registry(onewirehub)
 
-    hass.async_create_task(start_platforms())
+    hass.async_create_task(start_platforms(onewirehub))
 
     return True
 
diff --git a/homeassistant/components/onewire/binary_sensor.py b/homeassistant/components/onewire/binary_sensor.py
index 0a57e0c1b19..7f569b150c2 100644
--- a/homeassistant/components/onewire/binary_sensor.py
+++ b/homeassistant/components/onewire/binary_sensor.py
@@ -3,21 +3,16 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 import os
+from typing import TYPE_CHECKING
 
 from homeassistant.components.binary_sensor import (
     BinarySensorEntity,
     BinarySensorEntityDescription,
 )
+from homeassistant.components.onewire.model import OWServerDeviceDescription
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import (
-    ATTR_IDENTIFIERS,
-    ATTR_MANUFACTURER,
-    ATTR_MODEL,
-    ATTR_NAME,
-    CONF_TYPE,
-)
+from homeassistant.const import CONF_TYPE
 from homeassistant.core import HomeAssistant
-from homeassistant.helpers.entity import DeviceInfo
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 
 from .const import (
@@ -89,24 +84,17 @@ def get_entities(onewirehub: OneWireHub) -> list[BinarySensorEntity]:
         return []
 
     entities: list[BinarySensorEntity] = []
-
     for device in onewirehub.devices:
-        family = device["family"]
-        device_type = device["type"]
-        device_id = os.path.split(os.path.split(device["path"])[0])[1]
+        if TYPE_CHECKING:
+            assert isinstance(device, OWServerDeviceDescription)
+        family = device.family
+        device_id = device.id
+        device_info = device.device_info
 
         if family not in DEVICE_BINARY_SENSORS:
             continue
-        device_info: DeviceInfo = {
-            ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
-            ATTR_MANUFACTURER: "Maxim Integrated",
-            ATTR_MODEL: device_type,
-            ATTR_NAME: device_id,
-        }
         for description in DEVICE_BINARY_SENSORS[family]:
-            device_file = os.path.join(
-                os.path.split(device["path"])[0], description.key
-            )
+            device_file = os.path.join(os.path.split(device.path)[0], description.key)
             name = f"{device_id} {description.name}"
             entities.append(
                 OneWireProxyBinarySensor(
diff --git a/homeassistant/components/onewire/model.py b/homeassistant/components/onewire/model.py
index 2aaef861a50..370b26c2530 100644
--- a/homeassistant/components/onewire/model.py
+++ b/homeassistant/components/onewire/model.py
@@ -1,12 +1,32 @@
 """Type definitions for 1-Wire integration."""
 from __future__ import annotations
 
-from typing import TypedDict
+from dataclasses import dataclass
+
+from pi1wire import OneWireInterface
+
+from homeassistant.helpers.entity import DeviceInfo
 
 
-class OWServerDeviceDescription(TypedDict):
+@dataclass
+class OWDeviceDescription:
+    """OWDeviceDescription device description class."""
+
+    device_info: DeviceInfo
+
+
+@dataclass
+class OWDirectDeviceDescription(OWDeviceDescription):
+    """SysBus device description class."""
+
+    interface: OneWireInterface
+
+
+@dataclass
+class OWServerDeviceDescription(OWDeviceDescription):
     """OWServer device description class."""
 
-    path: str
     family: str
+    id: str
+    path: str
     type: str
diff --git a/homeassistant/components/onewire/onewirehub.py b/homeassistant/components/onewire/onewirehub.py
index d3b6773e74e..0d7f85fee68 100644
--- a/homeassistant/components/onewire/onewirehub.py
+++ b/homeassistant/components/onewire/onewirehub.py
@@ -1,24 +1,43 @@
 """Hub for communication with 1-Wire server or mount_dir."""
 from __future__ import annotations
 
+import logging
 import os
+from typing import TYPE_CHECKING
 
 from pi1wire import Pi1Wire
 from pyownet import protocol
 
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import CONF_HOST, CONF_PORT, CONF_TYPE
+from homeassistant.const import (
+    ATTR_IDENTIFIERS,
+    ATTR_MANUFACTURER,
+    ATTR_MODEL,
+    ATTR_NAME,
+    CONF_HOST,
+    CONF_PORT,
+    CONF_TYPE,
+)
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr
+from homeassistant.helpers.device_registry import DeviceEntry
+from homeassistant.helpers.entity import DeviceInfo
 
-from .const import CONF_MOUNT_DIR, CONF_TYPE_OWSERVER, CONF_TYPE_SYSBUS
-from .model import OWServerDeviceDescription
+from .const import CONF_MOUNT_DIR, CONF_TYPE_OWSERVER, CONF_TYPE_SYSBUS, DOMAIN
+from .model import (
+    OWDeviceDescription,
+    OWDirectDeviceDescription,
+    OWServerDeviceDescription,
+)
 
 DEVICE_COUPLERS = {
     # Family : [branches]
     "1F": ["aux", "main"]
 }
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class OneWireHub:
     """Hub to communicate with SysBus or OWServer."""
@@ -29,7 +48,7 @@ class OneWireHub:
         self.type: str | None = None
         self.pi1proxy: Pi1Wire | None = None
         self.owproxy: protocol._Proxy | None = None
-        self.devices: list | None = None
+        self.devices: list[OWDeviceDescription] | None = None
 
     async def connect(self, host: str, port: int) -> None:
         """Connect to the owserver host."""
@@ -56,42 +75,99 @@ class OneWireHub:
             port = config_entry.data[CONF_PORT]
             await self.connect(host, port)
         await self.discover_devices()
+        if TYPE_CHECKING:
+            assert self.devices
+        # Register discovered devices on Hub
+        device_registry = dr.async_get(self.hass)
+        for device in self.devices:
+            device_info: DeviceInfo = device.device_info
+            device_registry.async_get_or_create(
+                config_entry_id=config_entry.entry_id,
+                identifiers=device_info[ATTR_IDENTIFIERS],
+                manufacturer=device_info[ATTR_MANUFACTURER],
+                model=device_info[ATTR_MODEL],
+                name=device_info[ATTR_NAME],
+                via_device=device_info.get("via_device"),
+            )
 
     async def discover_devices(self) -> None:
         """Discover all devices."""
         if self.devices is None:
             if self.type == CONF_TYPE_SYSBUS:
-                assert self.pi1proxy
                 self.devices = await self.hass.async_add_executor_job(
-                    self.pi1proxy.find_all_sensors
+                    self._discover_devices_sysbus
                 )
             if self.type == CONF_TYPE_OWSERVER:
                 self.devices = await self.hass.async_add_executor_job(
                     self._discover_devices_owserver
                 )
 
+    def _discover_devices_sysbus(self) -> list[OWDeviceDescription]:
+        """Discover all sysbus devices."""
+        devices: list[OWDeviceDescription] = []
+        assert self.pi1proxy
+        for interface in self.pi1proxy.find_all_sensors():
+            family = interface.mac_address[:2]
+            device_id = f"{family}-{interface.mac_address[2:]}"
+            device_info: DeviceInfo = {
+                ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
+                ATTR_MANUFACTURER: "Maxim Integrated",
+                ATTR_MODEL: family,
+                ATTR_NAME: device_id,
+            }
+            device = OWDirectDeviceDescription(
+                device_info=device_info,
+                interface=interface,
+            )
+            devices.append(device)
+        return devices
+
     def _discover_devices_owserver(
-        self, path: str = "/"
-    ) -> list[OWServerDeviceDescription]:
+        self, path: str = "/", parent_id: str | None = None
+    ) -> list[OWDeviceDescription]:
         """Discover all owserver devices."""
-        devices = []
+        devices: list[OWDeviceDescription] = []
         assert self.owproxy
         for device_path in self.owproxy.dir(path):
+            device_id = os.path.split(os.path.split(device_path)[0])[1]
             device_family = self.owproxy.read(f"{device_path}family").decode()
+            _LOGGER.debug("read `%sfamily`: %s", device_path, device_family)
             device_type = self.owproxy.read(f"{device_path}type").decode()
+            _LOGGER.debug("read `%stype`: %s", device_path, device_type)
+            device_info: DeviceInfo = {
+                ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
+                ATTR_MANUFACTURER: "Maxim Integrated",
+                ATTR_MODEL: device_type,
+                ATTR_NAME: device_id,
+            }
+            if parent_id:
+                device_info["via_device"] = (DOMAIN, parent_id)
+            device = OWServerDeviceDescription(
+                device_info=device_info,
+                id=device_id,
+                family=device_family,
+                path=device_path,
+                type=device_type,
+            )
+            devices.append(device)
             if device_branches := DEVICE_COUPLERS.get(device_family):
                 for branch in device_branches:
-                    devices += self._discover_devices_owserver(f"{device_path}{branch}")
-            else:
-                devices.append(
-                    {
-                        "path": device_path,
-                        "family": device_family,
-                        "type": device_type,
-                    }
-                )
+                    devices += self._discover_devices_owserver(
+                        f"{device_path}{branch}", device_id
+                    )
+
         return devices
 
+    def has_device_in_cache(self, device: DeviceEntry) -> bool:
+        """Check if device was present in the cache."""
+        if TYPE_CHECKING:
+            assert self.devices
+        for internal_device in self.devices:
+            for identifier in internal_device.device_info[ATTR_IDENTIFIERS]:
+                if identifier in device.identifiers:
+                    return True
+        return False
+
 
 class CannotConnect(HomeAssistantError):
     """Error to indicate we cannot connect."""
diff --git a/homeassistant/components/onewire/sensor.py b/homeassistant/components/onewire/sensor.py
index b1f08b864b1..c13dcd4ebf3 100644
--- a/homeassistant/components/onewire/sensor.py
+++ b/homeassistant/components/onewire/sensor.py
@@ -7,10 +7,14 @@ from dataclasses import dataclass
 import logging
 import os
 from types import MappingProxyType
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from pi1wire import InvalidCRCException, OneWireInterface, UnsupportResponseException
 
+from homeassistant.components.onewire.model import (
+    OWDirectDeviceDescription,
+    OWServerDeviceDescription,
+)
 from homeassistant.components.sensor import (
     STATE_CLASS_MEASUREMENT,
     STATE_CLASS_TOTAL_INCREASING,
@@ -19,10 +23,6 @@ from homeassistant.components.sensor import (
 )
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import (
-    ATTR_IDENTIFIERS,
-    ATTR_MANUFACTURER,
-    ATTR_MODEL,
-    ATTR_NAME,
     CONF_TYPE,
     DEVICE_CLASS_CURRENT,
     DEVICE_CLASS_HUMIDITY,
@@ -382,11 +382,14 @@ def get_entities(
     if conf_type == CONF_TYPE_OWSERVER:
         assert onewirehub.owproxy
         for device in onewirehub.devices:
-            family = device["family"]
-            device_type = device["type"]
-            device_id = os.path.split(os.path.split(device["path"])[0])[1]
+            if TYPE_CHECKING:
+                assert isinstance(device, OWServerDeviceDescription)
+            family = device.family
+            device_type = device.type
+            device_id = device.id
+            device_info = device.device_info
             device_sub_type = "std"
-            device_path = device["path"]
+            device_path = device.path
             if "EF" in family:
                 device_sub_type = "HobbyBoard"
                 family = device_type
@@ -401,12 +404,6 @@ def get_entities(
                     device_id,
                 )
                 continue
-            device_info: DeviceInfo = {
-                ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
-                ATTR_MANUFACTURER: "Maxim Integrated",
-                ATTR_MODEL: device_type,
-                ATTR_NAME: device_id,
-            }
             for description in get_sensor_types(device_sub_type)[family]:
                 if description.key.startswith("moisture/"):
                     s_id = description.key.split(".")[1]
@@ -421,7 +418,7 @@ def get_entities(
                         description.native_unit_of_measurement = PERCENTAGE
                         description.name = f"Wetness {s_id}"
                 device_file = os.path.join(
-                    os.path.split(device["path"])[0], description.key
+                    os.path.split(device.path)[0], description.key
                 )
                 name = f"{device_names.get(device_id, device_id)} {description.name}"
                 entities.append(
@@ -439,9 +436,13 @@ def get_entities(
     elif conf_type == CONF_TYPE_SYSBUS:
         base_dir = config[CONF_MOUNT_DIR]
         _LOGGER.debug("Initializing using SysBus %s", base_dir)
-        for p1sensor in onewirehub.devices:
+        for device in onewirehub.devices:
+            if TYPE_CHECKING:
+                assert isinstance(device, OWDirectDeviceDescription)
+            p1sensor: OneWireInterface = device.interface
             family = p1sensor.mac_address[:2]
             device_id = f"{family}-{p1sensor.mac_address[2:]}"
+            device_info = device.device_info
             if family not in DEVICE_SUPPORT_SYSBUS:
                 _LOGGER.warning(
                     "Ignoring unknown family (%s) of sensor found for device: %s",
@@ -450,12 +451,6 @@ def get_entities(
                 )
                 continue
 
-            device_info = {
-                ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
-                ATTR_MANUFACTURER: "Maxim Integrated",
-                ATTR_MODEL: family,
-                ATTR_NAME: device_id,
-            }
             description = SIMPLE_TEMPERATURE_SENSOR_DESCRIPTION
             device_file = f"/sys/bus/w1/devices/{device_id}/w1_slave"
             name = f"{device_names.get(device_id, device_id)} {description.name}"
diff --git a/homeassistant/components/onewire/switch.py b/homeassistant/components/onewire/switch.py
index aadc1315712..712077c62bd 100644
--- a/homeassistant/components/onewire/switch.py
+++ b/homeassistant/components/onewire/switch.py
@@ -4,19 +4,13 @@ from __future__ import annotations
 from dataclasses import dataclass
 import logging
 import os
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
+from homeassistant.components.onewire.model import OWServerDeviceDescription
 from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import (
-    ATTR_IDENTIFIERS,
-    ATTR_MANUFACTURER,
-    ATTR_MODEL,
-    ATTR_NAME,
-    CONF_TYPE,
-)
+from homeassistant.const import CONF_TYPE
 from homeassistant.core import HomeAssistant
-from homeassistant.helpers.entity import DeviceInfo
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 
 from .const import (
@@ -120,23 +114,16 @@ def get_entities(onewirehub: OneWireHub) -> list[SwitchEntity]:
     entities: list[SwitchEntity] = []
 
     for device in onewirehub.devices:
-        family = device["family"]
-        device_type = device["type"]
-        device_id = os.path.split(os.path.split(device["path"])[0])[1]
+        if TYPE_CHECKING:
+            assert isinstance(device, OWServerDeviceDescription)
+        family = device.family
+        device_id = device.id
+        device_info = device.device_info
 
         if family not in DEVICE_SWITCHES:
             continue
-
-        device_info: DeviceInfo = {
-            ATTR_IDENTIFIERS: {(DOMAIN, device_id)},
-            ATTR_MANUFACTURER: "Maxim Integrated",
-            ATTR_MODEL: device_type,
-            ATTR_NAME: device_id,
-        }
         for description in DEVICE_SWITCHES[family]:
-            device_file = os.path.join(
-                os.path.split(device["path"])[0], description.key
-            )
+            device_file = os.path.join(os.path.split(device.path)[0], description.key)
             name = f"{device_id} {description.name}"
             entities.append(
                 OneWireProxySwitch(
diff --git a/tests/components/onewire/__init__.py b/tests/components/onewire/__init__.py
index 8223b1bc841..25ff4a15cfd 100644
--- a/tests/components/onewire/__init__.py
+++ b/tests/components/onewire/__init__.py
@@ -8,8 +8,16 @@ from unittest.mock import MagicMock
 from pyownet.protocol import ProtocolError
 
 from homeassistant.components.onewire.const import DEFAULT_SYSBUS_MOUNT_DIR
-from homeassistant.const import ATTR_ENTITY_ID, ATTR_STATE
+from homeassistant.const import (
+    ATTR_ENTITY_ID,
+    ATTR_IDENTIFIERS,
+    ATTR_MANUFACTURER,
+    ATTR_MODEL,
+    ATTR_NAME,
+    ATTR_STATE,
+)
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers.device_registry import DeviceRegistry
 from homeassistant.helpers.entity_registry import EntityRegistry
 
 from .const import (
@@ -36,6 +44,28 @@ def check_and_enable_disabled_entities(
             entity_registry.async_update_entity(entity_id, **{"disabled_by": None})
 
 
+def check_device_registry(
+    device_registry: DeviceRegistry, expected_devices: list[MappingProxyType]
+) -> None:
+    """Ensure that the expected_devices are correctly registered."""
+    for expected_device in expected_devices:
+        registry_entry = device_registry.async_get_device(
+            expected_device[ATTR_IDENTIFIERS]
+        )
+        assert registry_entry is not None
+        assert registry_entry.identifiers == expected_device[ATTR_IDENTIFIERS]
+        assert registry_entry.manufacturer == expected_device[ATTR_MANUFACTURER]
+        assert registry_entry.name == expected_device[ATTR_NAME]
+        assert registry_entry.model == expected_device[ATTR_MODEL]
+        if expected_via_device := expected_device.get("via_device"):
+            assert registry_entry.via_device_id is not None
+            parent_entry = device_registry.async_get_device({expected_via_device})
+            assert parent_entry is not None
+            assert registry_entry.via_device_id == parent_entry.id
+        else:
+            assert registry_entry.via_device_id is None
+
+
 def check_entities(
     hass: HomeAssistant,
     entity_registry: EntityRegistry,
@@ -57,39 +87,97 @@ def check_entities(
 
 
 def setup_owproxy_mock_devices(
-    owproxy: MagicMock, platform: str, device_ids: list(str)
+    owproxy: MagicMock, platform: str, device_ids: list[str]
 ) -> None:
     """Set up mock for owproxy."""
-    dir_return_value = []
+    main_dir_return_value = []
+    sub_dir_side_effect = []
     main_read_side_effect = []
     sub_read_side_effect = []
 
     for device_id in device_ids:
-        mock_device = MOCK_OWPROXY_DEVICES[device_id]
-
-        # Setup directory listing
-        dir_return_value += [f"/{device_id}/"]
-
-        # Setup device reads
-        main_read_side_effect += [device_id[0:2].encode()]
-        if ATTR_INJECT_READS in mock_device:
-            main_read_side_effect += mock_device[ATTR_INJECT_READS]
-
-        # Setup sub-device reads
-        device_sensors = mock_device.get(platform, [])
-        for expected_sensor in device_sensors:
-            sub_read_side_effect.append(expected_sensor[ATTR_INJECT_READS])
+        _setup_owproxy_mock_device(
+            main_dir_return_value,
+            sub_dir_side_effect,
+            main_read_side_effect,
+            sub_read_side_effect,
+            device_id,
+            platform,
+        )
 
     # Ensure enough read side effect
+    dir_side_effect = [main_dir_return_value] + sub_dir_side_effect
     read_side_effect = (
         main_read_side_effect
         + sub_read_side_effect
         + [ProtocolError("Missing injected value")] * 20
     )
-    owproxy.return_value.dir.return_value = dir_return_value
+    owproxy.return_value.dir.side_effect = dir_side_effect
     owproxy.return_value.read.side_effect = read_side_effect
 
 
+def _setup_owproxy_mock_device(
+    main_dir_return_value: list,
+    sub_dir_side_effect: list,
+    main_read_side_effect: list,
+    sub_read_side_effect: list,
+    device_id: str,
+    platform: str,
+) -> None:
+    """Set up mock for owproxy."""
+    mock_device = MOCK_OWPROXY_DEVICES[device_id]
+
+    # Setup directory listing
+    main_dir_return_value += [f"/{device_id}/"]
+    if "branches" in mock_device:
+        # Setup branch directory listing
+        for branch, branch_details in mock_device["branches"].items():
+            sub_dir_side_effect.append(
+                [  # dir on branch
+                    f"/{device_id}/{branch}/{sub_device_id}/"
+                    for sub_device_id in branch_details
+                ]
+            )
+
+    _setup_owproxy_mock_device_reads(
+        main_read_side_effect,
+        sub_read_side_effect,
+        mock_device,
+        device_id,
+        platform,
+    )
+
+    if "branches" in mock_device:
+        for branch_details in mock_device["branches"].values():
+            for sub_device_id, sub_device in branch_details.items():
+                _setup_owproxy_mock_device_reads(
+                    main_read_side_effect,
+                    sub_read_side_effect,
+                    sub_device,
+                    sub_device_id,
+                    platform,
+                )
+
+
+def _setup_owproxy_mock_device_reads(
+    main_read_side_effect: list,
+    sub_read_side_effect: list,
+    mock_device: Any,
+    device_id: str,
+    platform: str,
+) -> None:
+    """Set up mock for owproxy."""
+    # Setup device reads
+    main_read_side_effect += [device_id[0:2].encode()]
+    if ATTR_INJECT_READS in mock_device:
+        main_read_side_effect += mock_device[ATTR_INJECT_READS]
+
+    # Setup sub-device reads
+    device_sensors = mock_device.get(platform, [])
+    for expected_sensor in device_sensors:
+        sub_read_side_effect.append(expected_sensor[ATTR_INJECT_READS])
+
+
 def setup_sysbus_mock_devices(
     platform: str, device_ids: list[str]
 ) -> tuple[list[str], list[Any]]:
diff --git a/tests/components/onewire/const.py b/tests/components/onewire/const.py
index 93006ee4f81..91b7b618c37 100644
--- a/tests/components/onewire/const.py
+++ b/tests/components/onewire/const.py
@@ -212,12 +212,21 @@ MOCK_OWPROXY_DEVICES = {
         ATTR_INJECT_READS: [
             b"DS2409",  # read device type
         ],
-        ATTR_DEVICE_INFO: {
-            ATTR_IDENTIFIERS: {(DOMAIN, "1F.111111111111")},
-            ATTR_MANUFACTURER: MANUFACTURER,
-            ATTR_MODEL: "DS2409",
-            ATTR_NAME: "1F.111111111111",
-        },
+        ATTR_DEVICE_INFO: [
+            {
+                ATTR_IDENTIFIERS: {(DOMAIN, "1F.111111111111")},
+                ATTR_MANUFACTURER: MANUFACTURER,
+                ATTR_MODEL: "DS2409",
+                ATTR_NAME: "1F.111111111111",
+            },
+            {
+                ATTR_IDENTIFIERS: {(DOMAIN, "1D.111111111111")},
+                ATTR_MANUFACTURER: MANUFACTURER,
+                ATTR_MODEL: "DS2423",
+                ATTR_NAME: "1D.111111111111",
+                "via_device": (DOMAIN, "1F.111111111111"),
+            },
+        ],
         "branches": {
             "aux": {},
             "main": {
@@ -225,12 +234,6 @@ MOCK_OWPROXY_DEVICES = {
                     ATTR_INJECT_READS: [
                         b"DS2423",  # read device type
                     ],
-                    ATTR_DEVICE_INFO: {
-                        ATTR_IDENTIFIERS: {(DOMAIN, "1D.111111111111")},
-                        ATTR_MANUFACTURER: MANUFACTURER,
-                        ATTR_MODEL: "DS2423",
-                        ATTR_NAME: "1D.111111111111",
-                    },
                     SENSOR_DOMAIN: [
                         {
                             ATTR_DEVICE_FILE: "/1F.111111111111/main/1D.111111111111/counter.A",
diff --git a/tests/components/onewire/test_binary_sensor.py b/tests/components/onewire/test_binary_sensor.py
index ff9dd29c2c2..90e53924cab 100644
--- a/tests/components/onewire/test_binary_sensor.py
+++ b/tests/components/onewire/test_binary_sensor.py
@@ -6,15 +6,17 @@ import pytest
 from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers.config_validation import ensure_list
 
 from . import (
     check_and_enable_disabled_entities,
+    check_device_registry,
     check_entities,
     setup_owproxy_mock_devices,
 )
-from .const import MOCK_OWPROXY_DEVICES
+from .const import ATTR_DEVICE_INFO, MOCK_OWPROXY_DEVICES
 
-from tests.common import mock_registry
+from tests.common import mock_device_registry, mock_registry
 
 
 @pytest.fixture(autouse=True)
@@ -31,17 +33,19 @@ async def test_owserver_binary_sensor(
 
     This test forces all entities to be enabled.
     """
+    device_registry = mock_device_registry(hass)
     entity_registry = mock_registry(hass)
 
     mock_device = MOCK_OWPROXY_DEVICES[device_id]
     expected_entities = mock_device.get(BINARY_SENSOR_DOMAIN, [])
+    expected_devices = ensure_list(mock_device.get(ATTR_DEVICE_INFO))
 
     setup_owproxy_mock_devices(owproxy, BINARY_SENSOR_DOMAIN, [device_id])
     await hass.config_entries.async_setup(config_entry.entry_id)
     await hass.async_block_till_done()
 
+    check_device_registry(device_registry, expected_devices)
     assert len(entity_registry.entities) == len(expected_entities)
-
     check_and_enable_disabled_entities(entity_registry, expected_entities)
 
     setup_owproxy_mock_devices(owproxy, BINARY_SENSOR_DOMAIN, [device_id])
diff --git a/tests/components/onewire/test_sensor.py b/tests/components/onewire/test_sensor.py
index 23af72bac41..ffa9d0b5319 100644
--- a/tests/components/onewire/test_sensor.py
+++ b/tests/components/onewire/test_sensor.py
@@ -1,44 +1,24 @@
 """Tests for 1-Wire sensor platform."""
 from unittest.mock import MagicMock, patch
 
-from pyownet.protocol import Error as ProtocolError
 import pytest
 
-from homeassistant.components.onewire.const import DOMAIN
-from homeassistant.components.sensor import ATTR_STATE_CLASS, DOMAIN as SENSOR_DOMAIN
+from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import (
-    ATTR_DEVICE_CLASS,
-    ATTR_ENTITY_ID,
-    ATTR_MANUFACTURER,
-    ATTR_MODEL,
-    ATTR_NAME,
-    ATTR_STATE,
-    ATTR_UNIT_OF_MEASUREMENT,
-)
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers.config_validation import ensure_list
 
 from . import (
     check_and_enable_disabled_entities,
+    check_device_registry,
     check_entities,
     setup_owproxy_mock_devices,
     setup_sysbus_mock_devices,
 )
-from .const import (
-    ATTR_DEVICE_FILE,
-    ATTR_DEVICE_INFO,
-    ATTR_INJECT_READS,
-    ATTR_UNIQUE_ID,
-    MOCK_OWPROXY_DEVICES,
-    MOCK_SYSBUS_DEVICES,
-)
+from .const import ATTR_DEVICE_INFO, MOCK_OWPROXY_DEVICES, MOCK_SYSBUS_DEVICES
 
 from tests.common import mock_device_registry, mock_registry
 
-MOCK_COUPLERS = {
-    key: value for (key, value) in MOCK_OWPROXY_DEVICES.items() if "branches" in value
-}
-
 
 @pytest.fixture(autouse=True)
 def override_platforms():
@@ -47,99 +27,36 @@ def override_platforms():
         yield
 
 
-@pytest.mark.parametrize("device_id", ["1F.111111111111"], indirect=True)
-async def test_sensors_on_owserver_coupler(
-    hass: HomeAssistant, config_entry: ConfigEntry, owproxy: MagicMock, device_id: str
-):
-    """Test for 1-Wire sensors connected to DS2409 coupler."""
-
-    entity_registry = mock_registry(hass)
-
-    mock_coupler = MOCK_COUPLERS[device_id]
-
-    dir_side_effect = []  # List of lists of string
-    read_side_effect = []  # List of byte arrays
-
-    dir_side_effect.append([f"/{device_id}/"])  # dir on root
-    read_side_effect.append(device_id[0:2].encode())  # read family on root
-    if ATTR_INJECT_READS in mock_coupler:
-        read_side_effect += mock_coupler[ATTR_INJECT_READS]
-
-    expected_entities = []
-    for branch, branch_details in mock_coupler["branches"].items():
-        dir_side_effect.append(
-            [  # dir on branch
-                f"/{device_id}/{branch}/{sub_device_id}/"
-                for sub_device_id in branch_details
-            ]
-        )
-
-        for sub_device_id, sub_device in branch_details.items():
-            read_side_effect.append(sub_device_id[0:2].encode())
-            if ATTR_INJECT_READS in sub_device:
-                read_side_effect.extend(sub_device[ATTR_INJECT_READS])
-
-            expected_entities += sub_device[SENSOR_DOMAIN]
-            for expected_entity in sub_device[SENSOR_DOMAIN]:
-                read_side_effect.append(expected_entity[ATTR_INJECT_READS])
-
-    # Ensure enough read side effect
-    read_side_effect.extend([ProtocolError("Missing injected value")] * 10)
-    owproxy.return_value.dir.side_effect = dir_side_effect
-    owproxy.return_value.read.side_effect = read_side_effect
-
-    await hass.config_entries.async_setup(config_entry.entry_id)
-    await hass.async_block_till_done()
-
-    assert len(entity_registry.entities) == len(expected_entities)
-
-    for expected_entity in expected_entities:
-        entity_id = expected_entity[ATTR_ENTITY_ID]
-        registry_entry = entity_registry.entities.get(entity_id)
-        assert registry_entry is not None
-        assert registry_entry.unique_id == expected_entity[ATTR_UNIQUE_ID]
-        state = hass.states.get(entity_id)
-        assert state.state == expected_entity[ATTR_STATE]
-        for attr in (ATTR_DEVICE_CLASS, ATTR_STATE_CLASS, ATTR_UNIT_OF_MEASUREMENT):
-            assert state.attributes.get(attr) == expected_entity.get(attr)
-        assert state.attributes[ATTR_DEVICE_FILE] == expected_entity[ATTR_DEVICE_FILE]
-
-
-async def test_owserver_setup_valid_device(
+async def test_owserver_sensor(
     hass: HomeAssistant, config_entry: ConfigEntry, owproxy: MagicMock, device_id: str
 ):
     """Test for 1-Wire device.
 
     As they would be on a clean setup: all binary-sensors and switches disabled.
     """
-    entity_registry = mock_registry(hass)
     device_registry = mock_device_registry(hass)
+    entity_registry = mock_registry(hass)
 
     mock_device = MOCK_OWPROXY_DEVICES[device_id]
     expected_entities = mock_device.get(SENSOR_DOMAIN, [])
+    if "branches" in mock_device:
+        for branch_details in mock_device["branches"].values():
+            for sub_device in branch_details.values():
+                expected_entities += sub_device[SENSOR_DOMAIN]
+    expected_devices = ensure_list(mock_device.get(ATTR_DEVICE_INFO))
 
     setup_owproxy_mock_devices(owproxy, SENSOR_DOMAIN, [device_id])
     await hass.config_entries.async_setup(config_entry.entry_id)
     await hass.async_block_till_done()
 
+    check_device_registry(device_registry, expected_devices)
     assert len(entity_registry.entities) == len(expected_entities)
-
     check_and_enable_disabled_entities(entity_registry, expected_entities)
 
     setup_owproxy_mock_devices(owproxy, SENSOR_DOMAIN, [device_id])
     await hass.config_entries.async_reload(config_entry.entry_id)
     await hass.async_block_till_done()
 
-    if len(expected_entities) > 0:
-        device_info = mock_device[ATTR_DEVICE_INFO]
-        assert len(device_registry.devices) == 1
-        registry_entry = device_registry.async_get_device({(DOMAIN, device_id)})
-        assert registry_entry is not None
-        assert registry_entry.identifiers == {(DOMAIN, device_id)}
-        assert registry_entry.manufacturer == device_info[ATTR_MANUFACTURER]
-        assert registry_entry.name == device_info[ATTR_NAME]
-        assert registry_entry.model == device_info[ATTR_MODEL]
-
     check_entities(hass, entity_registry, expected_entities)
 
 
@@ -149,9 +66,8 @@ async def test_onewiredirect_setup_valid_device(
     hass: HomeAssistant, sysbus_config_entry: ConfigEntry, device_id: str
 ):
     """Test that sysbus config entry works correctly."""
-
-    entity_registry = mock_registry(hass)
     device_registry = mock_device_registry(hass)
+    entity_registry = mock_registry(hass)
 
     glob_result, read_side_effect = setup_sysbus_mock_devices(
         SENSOR_DOMAIN, [device_id]
@@ -159,6 +75,7 @@ async def test_onewiredirect_setup_valid_device(
 
     mock_device = MOCK_SYSBUS_DEVICES[device_id]
     expected_entities = mock_device.get(SENSOR_DOMAIN, [])
+    expected_devices = ensure_list(mock_device.get(ATTR_DEVICE_INFO))
 
     with patch("pi1wire._finder.glob.glob", return_value=glob_result,), patch(
         "pi1wire.OneWire.get_temperature",
@@ -167,16 +84,6 @@ async def test_onewiredirect_setup_valid_device(
         await hass.config_entries.async_setup(sysbus_config_entry.entry_id)
         await hass.async_block_till_done()
 
+    check_device_registry(device_registry, expected_devices)
     assert len(entity_registry.entities) == len(expected_entities)
-
-    if len(expected_entities) > 0:
-        device_info = mock_device[ATTR_DEVICE_INFO]
-        assert len(device_registry.devices) == 1
-        registry_entry = device_registry.async_get_device({(DOMAIN, device_id)})
-        assert registry_entry is not None
-        assert registry_entry.identifiers == {(DOMAIN, device_id)}
-        assert registry_entry.manufacturer == device_info[ATTR_MANUFACTURER]
-        assert registry_entry.name == device_info[ATTR_NAME]
-        assert registry_entry.model == device_info[ATTR_MODEL]
-
     check_entities(hass, entity_registry, expected_entities)
diff --git a/tests/components/onewire/test_switch.py b/tests/components/onewire/test_switch.py
index 766a41a5862..ffe5042b514 100644
--- a/tests/components/onewire/test_switch.py
+++ b/tests/components/onewire/test_switch.py
@@ -13,15 +13,17 @@ from homeassistant.const import (
     STATE_ON,
 )
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers.config_validation import ensure_list
 
 from . import (
     check_and_enable_disabled_entities,
+    check_device_registry,
     check_entities,
     setup_owproxy_mock_devices,
 )
-from .const import MOCK_OWPROXY_DEVICES
+from .const import ATTR_DEVICE_INFO, MOCK_OWPROXY_DEVICES
 
-from tests.common import mock_registry
+from tests.common import mock_device_registry, mock_registry
 
 
 @pytest.fixture(autouse=True)
@@ -38,17 +40,19 @@ async def test_owserver_switch(
 
     This test forces all entities to be enabled.
     """
+    device_registry = mock_device_registry(hass)
     entity_registry = mock_registry(hass)
 
     mock_device = MOCK_OWPROXY_DEVICES[device_id]
     expected_entities = mock_device.get(SWITCH_DOMAIN, [])
+    expected_devices = ensure_list(mock_device.get(ATTR_DEVICE_INFO))
 
     setup_owproxy_mock_devices(owproxy, SWITCH_DOMAIN, [device_id])
     await hass.config_entries.async_setup(config_entry.entry_id)
     await hass.async_block_till_done()
 
+    check_device_registry(device_registry, expected_devices)
     assert len(entity_registry.entities) == len(expected_entities)
-
     check_and_enable_disabled_entities(entity_registry, expected_entities)
 
     setup_owproxy_mock_devices(owproxy, SWITCH_DOMAIN, [device_id])