From 869f970e59153c45d1d57f81a1cbdbc74834e47d Mon Sep 17 00:00:00 2001 From: luar123 <49960470+luar123@users.noreply.github.com> Date: Wed, 24 May 2023 08:16:09 +0200 Subject: [PATCH] Fix Snapcast connection issues (#93010) * Add (dis)connect and update listeners, terminate connection and reconnect. Set availability * Pass entry_id to constructor --- homeassistant/components/snapcast/__init__.py | 8 +- .../components/snapcast/media_player.py | 43 ++++-- homeassistant/components/snapcast/server.py | 140 +++++++++++++++++- 3 files changed, 173 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/snapcast/__init__.py b/homeassistant/components/snapcast/__init__.py index 309669a8496..d8ff55cc175 100644 --- a/homeassistant/components/snapcast/__init__.py +++ b/homeassistant/components/snapcast/__init__.py @@ -27,7 +27,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: f"Could not connect to Snapcast server at {host}:{port}" ) from ex - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = HomeAssistantSnapcast(server) + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = HomeAssistantSnapcast( + hass, server, f"{host}:{port}", entry.entry_id + ) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) @@ -37,5 +39,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): - hass.data[DOMAIN].pop(entry.entry_id) + snapcast_data = hass.data[DOMAIN].pop(entry.entry_id) + # disconnect from server + await snapcast_data.disconnect() return unload_ok diff --git a/homeassistant/components/snapcast/media_player.py b/homeassistant/components/snapcast/media_player.py index bb54bfabf9f..624bf7463ba 100644 --- a/homeassistant/components/snapcast/media_player.py +++ b/homeassistant/components/snapcast/media_player.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from snapcast.control.server import CONTROL_PORT +from snapcast.control.server import CONTROL_PORT, Snapserver import voluptuous as vol from homeassistant.components.media_player import ( @@ -34,7 +34,6 @@ from .const import ( SERVICE_SNAPSHOT, SERVICE_UNJOIN, ) -from .server import HomeAssistantSnapcast _LOGGER = logging.getLogger(__name__) @@ -72,7 +71,7 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the snapcast config entry.""" - snapcast_data: HomeAssistantSnapcast = hass.data[DOMAIN][config_entry.entry_id] + snapcast_server: Snapserver = hass.data[DOMAIN][config_entry.entry_id].server register_services() @@ -80,14 +79,18 @@ async def async_setup_entry( port = config_entry.data[CONF_PORT] hpid = f"{host}:{port}" - snapcast_data.groups = [ - SnapcastGroupDevice(group, hpid) for group in snapcast_data.server.groups + groups: list[MediaPlayerEntity] = [ + SnapcastGroupDevice(group, hpid, config_entry.entry_id) + for group in snapcast_server.groups ] - snapcast_data.clients = [ + clients: list[MediaPlayerEntity] = [ SnapcastClientDevice(client, hpid, config_entry.entry_id) - for client in snapcast_data.server.clients + for client in snapcast_server.clients ] - async_add_entities(snapcast_data.clients + snapcast_data.groups) + async_add_entities(clients + groups) + hass.data[DOMAIN][ + config_entry.entry_id + ].hass_async_add_entities = async_add_entities async def async_setup_platform( @@ -147,18 +150,27 @@ class SnapcastGroupDevice(MediaPlayerEntity): | MediaPlayerEntityFeature.SELECT_SOURCE ) - def __init__(self, group, uid_part): + def __init__(self, group, uid_part, entry_id): """Initialize the Snapcast group device.""" + self._attr_available = True self._group = group + self._entry_id = entry_id self._uid = f"{GROUP_PREFIX}{uid_part}_{self._group.identifier}" async def async_added_to_hass(self) -> None: """Subscribe to group events.""" self._group.set_callback(self.schedule_update_ha_state) + self.hass.data[DOMAIN][self._entry_id].groups.append(self) async def async_will_remove_from_hass(self) -> None: """Disconnect group object when removed.""" self._group.set_callback(None) + self.hass.data[DOMAIN][self._entry_id].groups.remove(self) + + def set_availability(self, available: bool) -> None: + """Set availability of group.""" + self._attr_available = available + self.schedule_update_ha_state() @property def state(self) -> MediaPlayerState | None: @@ -172,6 +184,11 @@ class SnapcastGroupDevice(MediaPlayerEntity): """Return the ID of snapcast group.""" return self._uid + @property + def identifier(self): + """Return the snapcast identifier.""" + return self._group.identifier + @property def name(self): """Return the name of the device.""" @@ -236,6 +253,7 @@ class SnapcastClientDevice(MediaPlayerEntity): def __init__(self, client, uid_part, entry_id): """Initialize the Snapcast client device.""" + self._attr_available = True self._client = client self._uid = f"{CLIENT_PREFIX}{uid_part}_{self._client.identifier}" self._entry_id = entry_id @@ -243,10 +261,17 @@ class SnapcastClientDevice(MediaPlayerEntity): async def async_added_to_hass(self) -> None: """Subscribe to client events.""" self._client.set_callback(self.schedule_update_ha_state) + self.hass.data[DOMAIN][self._entry_id].clients.append(self) async def async_will_remove_from_hass(self) -> None: """Disconnect client object when removed.""" self._client.set_callback(None) + self.hass.data[DOMAIN][self._entry_id].clients.remove(self) + + def set_availability(self, available: bool) -> None: + """Set availability of group.""" + self._attr_available = available + self.schedule_update_ha_state() @property def unique_id(self): diff --git a/homeassistant/components/snapcast/server.py b/homeassistant/components/snapcast/server.py index 507ad6393a2..6a787dd5e88 100644 --- a/homeassistant/components/snapcast/server.py +++ b/homeassistant/components/snapcast/server.py @@ -1,15 +1,141 @@ """Snapcast Integration.""" -from dataclasses import dataclass, field +from __future__ import annotations -from snapcast.control import Snapserver +import logging + +import snapcast.control +from snapcast.control.client import Snapclient from homeassistant.components.media_player import MediaPlayerEntity +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .media_player import SnapcastClientDevice, SnapcastGroupDevice + +_LOGGER = logging.getLogger(__name__) -@dataclass class HomeAssistantSnapcast: - """Snapcast data stored in the Home Assistant data object.""" + """Snapcast server and data stored in the Home Assistant data object.""" - server: Snapserver - clients: list[MediaPlayerEntity] = field(default_factory=list) - groups: list[MediaPlayerEntity] = field(default_factory=list) + hass: HomeAssistant + + def __init__( + self, + hass: HomeAssistant, + server: snapcast.control.Snapserver, + hpid: str, + entry_id: str, + ) -> None: + """Initialize the HomeAssistantSnapcast object. + + Parameters + ---------- + hass: HomeAssistant + hass object + server : snapcast.control.Snapserver + Snapcast server + hpid : str + host and port + entry_id: str + ConfigEntry entry_id + + Returns + ------- + None + + """ + self.hass: HomeAssistant = hass + self.server: snapcast.control.Snapserver = server + self.hpid: str = hpid + self._entry_id = entry_id + self.clients: list[SnapcastClientDevice] = [] + self.groups: list[SnapcastGroupDevice] = [] + self.hass_async_add_entities: AddEntitiesCallback + # connect callbacks + self.server.set_on_update_callback(self.on_update) + self.server.set_on_connect_callback(self.on_connect) + self.server.set_on_disconnect_callback(self.on_disconnect) + self.server.set_new_client_callback(self.on_add_client) + + async def disconnect(self) -> None: + """Disconnect from server.""" + self.server.set_on_update_callback(None) + self.server.set_on_connect_callback(None) + self.server.set_on_disconnect_callback(None) + self.server.set_new_client_callback(None) + await self.server.stop() + + def on_update(self) -> None: + """Update all entities. + + Retrieve all groups/clients from server and add/update/delete entities. + """ + if not self.hass_async_add_entities: + return + new_groups: list[MediaPlayerEntity] = [] + groups: list[MediaPlayerEntity] = [] + hass_groups = {g.identifier: g for g in self.groups} + for group in self.server.groups: + if group.identifier in hass_groups: + groups.append(hass_groups[group.identifier]) + hass_groups[group.identifier].async_schedule_update_ha_state() + else: + new_groups.append(SnapcastGroupDevice(group, self.hpid, self._entry_id)) + new_clients: list[MediaPlayerEntity] = [] + clients: list[MediaPlayerEntity] = [] + hass_clients = {c.identifier: c for c in self.clients} + for client in self.server.clients: + if client.identifier in hass_clients: + clients.append(hass_clients[client.identifier]) + hass_clients[client.identifier].async_schedule_update_ha_state() + else: + new_clients.append( + SnapcastClientDevice(client, self.hpid, self._entry_id) + ) + del_entities: list[MediaPlayerEntity] = [ + x for x in self.groups if x not in groups + ] + del_entities.extend([x for x in self.clients if x not in clients]) + + _LOGGER.debug("New clients: %s", str(new_clients)) + _LOGGER.debug("New groups: %s", str(new_groups)) + _LOGGER.debug("Delete: %s", str(del_entities)) + + ent_reg = er.async_get(self.hass) + for entity in del_entities: + ent_reg.async_remove(entity.entity_id) + self.hass_async_add_entities(new_clients + new_groups) + + def on_connect(self) -> None: + """Activate all entities and update.""" + for client in self.clients: + client.set_availability(True) + for group in self.groups: + group.set_availability(True) + _LOGGER.info("Server connected: %s", self.hpid) + self.on_update() + + def on_disconnect(self, ex: Exception | None) -> None: + """Deactivate all entities.""" + for client in self.clients: + client.set_availability(False) + for group in self.groups: + group.set_availability(False) + _LOGGER.warning( + "Server disconnected: %s. Trying to reconnect. %s", self.hpid, str(ex or "") + ) + + def on_add_client(self, client: Snapclient) -> None: + """Add a Snapcast client. + + Parameters + ---------- + client : Snapclient + Snapcast client to be added to HA. + """ + if not self.hass_async_add_entities: + return + clients = [SnapcastClientDevice(client, self.hpid, self._entry_id)] + self.hass_async_add_entities(clients)