Use asyncio lock (#21985)

This commit is contained in:
Anders Melchiorsen 2019-03-13 10:17:09 +01:00 committed by GitHub
parent 0162e2abe5
commit c8692fe70c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 174 additions and 119 deletions

View file

@ -3,7 +3,7 @@ import datetime
import functools as ft
import logging
import socket
import threading
import asyncio
import urllib
import requests
@ -111,11 +111,11 @@ SONOS_SET_OPTION_SCHEMA = SONOS_SCHEMA.extend({
class SonosData:
"""Storage class for platform global data."""
def __init__(self):
def __init__(self, hass):
"""Initialize the data."""
self.uids = set()
self.entities = []
self.topology_lock = threading.Lock()
self.topology_lock = asyncio.Lock(loop=hass.loop)
def setup_platform(hass, config, add_entities, discovery_info=None):
@ -143,7 +143,7 @@ def _setup_platform(hass, config, add_entities, discovery_info):
import pysonos
if DATA_SONOS not in hass.data:
hass.data[DATA_SONOS] = SonosData()
hass.data[DATA_SONOS] = SonosData(hass)
advertise_addr = config.get(CONF_ADVERTISE_ADDR)
if advertise_addr:
@ -187,57 +187,62 @@ def _setup_platform(hass, config, add_entities, discovery_info):
add_entities(SonosEntity(p) for p in players)
_LOGGER.debug("Added %s Sonos speakers", len(players))
def service_handle(service):
"""Handle for services."""
def _service_to_entities(service):
"""Extract and return entities from service call."""
entity_ids = service.data.get('entity_id')
entities = hass.data[DATA_SONOS].entities
if entity_ids:
entities = [e for e in entities if e.entity_id in entity_ids]
with hass.data[DATA_SONOS].topology_lock:
if service.service == SERVICE_SNAPSHOT:
SonosEntity.snapshot_multi(
entities, service.data[ATTR_WITH_GROUP])
elif service.service == SERVICE_RESTORE:
SonosEntity.restore_multi(
entities, service.data[ATTR_WITH_GROUP])
elif service.service == SERVICE_JOIN:
master = [e for e in hass.data[DATA_SONOS].entities
if e.entity_id == service.data[ATTR_MASTER]]
if master:
master[0].join(entities)
else:
for entity in entities:
if service.service == SERVICE_UNJOIN:
entity.unjoin()
elif service.service == SERVICE_SET_TIMER:
entity.set_sleep_timer(service.data[ATTR_SLEEP_TIME])
elif service.service == SERVICE_CLEAR_TIMER:
entity.clear_sleep_timer()
elif service.service == SERVICE_UPDATE_ALARM:
entity.set_alarm(**service.data)
elif service.service == SERVICE_SET_OPTION:
entity.set_option(**service.data)
return entities
entity.schedule_update_ha_state(True)
async def async_service_handle(service):
"""Handle async services."""
entities = _service_to_entities(service)
if service.service == SERVICE_JOIN:
master = [e for e in hass.data[DATA_SONOS].entities
if e.entity_id == service.data[ATTR_MASTER]]
if master:
await SonosEntity.join_multi(hass, master[0], entities)
elif service.service == SERVICE_UNJOIN:
await SonosEntity.unjoin_multi(hass, entities)
elif service.service == SERVICE_SNAPSHOT:
await SonosEntity.snapshot_multi(
hass, entities, service.data[ATTR_WITH_GROUP])
elif service.service == SERVICE_RESTORE:
await SonosEntity.restore_multi(
hass, entities, service.data[ATTR_WITH_GROUP])
hass.services.register(
DOMAIN, SERVICE_JOIN, service_handle,
DOMAIN, SERVICE_JOIN, async_service_handle,
schema=SONOS_JOIN_SCHEMA)
hass.services.register(
DOMAIN, SERVICE_UNJOIN, service_handle,
DOMAIN, SERVICE_UNJOIN, async_service_handle,
schema=SONOS_SCHEMA)
hass.services.register(
DOMAIN, SERVICE_SNAPSHOT, service_handle,
DOMAIN, SERVICE_SNAPSHOT, async_service_handle,
schema=SONOS_STATES_SCHEMA)
hass.services.register(
DOMAIN, SERVICE_RESTORE, service_handle,
DOMAIN, SERVICE_RESTORE, async_service_handle,
schema=SONOS_STATES_SCHEMA)
def service_handle(service):
"""Handle sync services."""
for entity in _service_to_entities(service):
if service.service == SERVICE_SET_TIMER:
entity.set_sleep_timer(service.data[ATTR_SLEEP_TIME])
elif service.service == SERVICE_CLEAR_TIMER:
entity.clear_sleep_timer()
elif service.service == SERVICE_UPDATE_ALARM:
entity.set_alarm(**service.data)
elif service.service == SERVICE_SET_OPTION:
entity.set_option(**service.data)
hass.services.register(
DOMAIN, SERVICE_SET_TIMER, service_handle,
schema=SONOS_SET_TIMER_SCHEMA)
@ -701,52 +706,68 @@ class SonosEntity(MediaPlayerDevice):
self._speech_enhance = self.soco.dialog_mode
def update_groups(self, event=None):
"""Process a zone group topology event coming from a player."""
"""Handle callback for topology change event."""
def _get_soco_group():
"""Ask SoCo cache for existing topology."""
coordinator_uid = self.unique_id
slave_uids = []
try:
if self.soco.group and self.soco.group.coordinator:
coordinator_uid = self.soco.group.coordinator.uid
slave_uids = [p.uid for p in self.soco.group.members
if p.uid != coordinator_uid]
except requests.exceptions.RequestException:
pass
return [coordinator_uid] + slave_uids
async def _async_extract_group(event):
"""Extract group layout from a topology event."""
group = event and event.zone_player_uui_ds_in_group
if group:
return group.split(',')
return await self.hass.async_add_executor_job(_get_soco_group)
def _async_regroup(group):
"""Rebuild internal group layout."""
sonos_group = []
for uid in group:
entity = _get_entity_from_soco_uid(self.hass, uid)
if entity:
sonos_group.append(entity)
self._coordinator = None
self._sonos_group = sonos_group
self.async_schedule_update_ha_state()
for slave_uid in group[1:]:
slave = _get_entity_from_soco_uid(self.hass, slave_uid)
if slave:
# pylint: disable=protected-access
slave._coordinator = self
slave._sonos_group = sonos_group
slave.async_schedule_update_ha_state()
async def _async_handle_group_event(event):
"""Get async lock and handle event."""
async with self.hass.data[DATA_SONOS].topology_lock:
group = await _async_extract_group(event)
if self.unique_id == group[0]:
if self._restore_pending:
await self.hass.async_add_executor_job(self.restore)
_async_regroup(group)
if event:
self._receives_events = True
if not hasattr(event, 'zone_player_uui_ds_in_group'):
return
with self.hass.data[DATA_SONOS].topology_lock:
group = event and event.zone_player_uui_ds_in_group
if group:
# New group information is pushed
coordinator_uid, *slave_uids = group.split(',')
else:
coordinator_uid = self.unique_id
slave_uids = []
# Try SoCo cache for existing topology
try:
if self.soco.group and self.soco.group.coordinator:
coordinator_uid = self.soco.group.coordinator.uid
slave_uids = [p.uid for p in self.soco.group.members
if p.uid != coordinator_uid]
except requests.exceptions.RequestException:
pass
if self.unique_id == coordinator_uid:
if self._restore_pending:
self.restore()
sonos_group = []
for uid in (coordinator_uid, *slave_uids):
entity = _get_entity_from_soco_uid(self.hass, uid)
if entity:
sonos_group.append(entity)
self._coordinator = None
self._sonos_group = sonos_group
self.schedule_update_ha_state()
for slave_uid in slave_uids:
slave = _get_entity_from_soco_uid(self.hass, slave_uid)
if slave:
# pylint: disable=protected-access
slave._coordinator = self
slave._sonos_group = sonos_group
slave.schedule_update_ha_state()
self.hass.add_job(_async_handle_group_event(event))
def update_content(self, event=None):
"""Update information about available content."""
@ -974,12 +995,29 @@ class SonosEntity(MediaPlayerDevice):
# pylint: disable=protected-access
slave._coordinator = self
@staticmethod
async def join_multi(hass, master, entities):
"""Form a group with other players."""
async with hass.data[DATA_SONOS].topology_lock:
await hass.async_add_executor_job(master.join, entities)
@soco_error()
def unjoin(self):
"""Unjoin the player from a group."""
self.soco.unjoin()
self._coordinator = None
@staticmethod
async def unjoin_multi(hass, entities):
"""Unjoin several players from their group."""
def _unjoin_all(entities):
"""Sync helper."""
for entity in entities:
entity.unjoin()
async with hass.data[DATA_SONOS].topology_lock:
await hass.async_add_executor_job(_unjoin_all, entities)
@soco_error()
def snapshot(self, with_group):
"""Snapshot the state of a player."""
@ -992,6 +1030,25 @@ class SonosEntity(MediaPlayerDevice):
else:
self._snapshot_group = None
@staticmethod
async def snapshot_multi(hass, entities, with_group):
"""Snapshot all the entities and optionally their groups."""
# pylint: disable=protected-access
def _snapshot_all(entities):
"""Sync helper."""
for entity in entities:
entity.snapshot(with_group)
# Find all affected players
entities = set(entities)
if with_group:
for entity in list(entities):
entities.update(entity._sonos_group)
async with hass.data[DATA_SONOS].topology_lock:
await hass.async_add_executor_job(_snapshot_all, entities)
@soco_error()
def restore(self):
"""Restore a snapshotted state to a player."""
@ -1010,56 +1067,49 @@ class SonosEntity(MediaPlayerDevice):
self._restore_pending = False
@staticmethod
def snapshot_multi(entities, with_group):
"""Snapshot all the entities and optionally their groups."""
# pylint: disable=protected-access
# Find all affected players
entities = set(entities)
if with_group:
for entity in list(entities):
entities.update(entity._sonos_group)
for entity in entities:
entity.snapshot(with_group)
@staticmethod
def restore_multi(entities, with_group):
async def restore_multi(hass, entities, with_group):
"""Restore snapshots for all the entities."""
# pylint: disable=protected-access
def _restore_all(entities):
"""Sync helper."""
# Pause all current coordinators
for entity in (e for e in entities if e.is_coordinator):
if entity.state == STATE_PLAYING:
entity.media_pause()
if with_group:
# Unjoin slaves that are not already in their target group
for entity in [e for e in entities if not e.is_coordinator]:
if entity._snapshot_group != entity._sonos_group:
entity.unjoin()
# Bring back the original group topology
for entity in (e for e in entities if e._snapshot_group):
if entity._snapshot_group[0] == entity:
entity.join(entity._snapshot_group)
# Restore slaves
for entity in (e for e in entities if not e.is_coordinator):
entity.restore()
# Restore coordinators (or delay if moving from slave)
for entity in (e for e in entities if e.is_coordinator):
if entity._sonos_group[0] == entity:
# Was already coordinator
entity.restore()
else:
# Await coordinator role
entity._restore_pending = True
# Find all affected players
entities = set(e for e in entities if e._soco_snapshot)
if with_group:
for entity in [e for e in entities if e._snapshot_group]:
entities.update(entity._snapshot_group)
# Pause all current coordinators
for entity in (e for e in entities if e.is_coordinator):
if entity.state == STATE_PLAYING:
entity.media_pause()
if with_group:
# Unjoin slaves that are not already in their target group
for entity in [e for e in entities if not e.is_coordinator]:
if entity._snapshot_group != entity._sonos_group:
entity.unjoin()
# Bring back the original group topology
for entity in (e for e in entities if e._snapshot_group):
if entity._snapshot_group[0] == entity:
entity.join(entity._snapshot_group)
# Restore slaves
for entity in (e for e in entities if not e.is_coordinator):
entity.restore()
# Restore coordinators (or delay if moving from slave)
for entity in (e for e in entities if e.is_coordinator):
if entity._sonos_group[0] == entity:
# Was already coordinator
entity.restore()
else:
# Await coordinator role
entity._restore_pending = True
async with hass.data[DATA_SONOS].topology_lock:
await hass.async_add_executor_job(_restore_all, entities)
@soco_error()
@soco_coordinator

View file

@ -12,6 +12,7 @@ from homeassistant.components.sonos import media_player as sonos
from homeassistant.components.media_player.const import DOMAIN
from homeassistant.components.sonos.media_player import CONF_INTERFACE_ADDR
from homeassistant.const import CONF_HOSTS, CONF_PLATFORM
from homeassistant.util.async_ import run_coroutine_threadsafe
from tests.common import get_test_home_assistant
@ -328,7 +329,9 @@ class TestSonosMediaPlayer(unittest.TestCase):
snapshotMock.return_value = True
entity.soco.group = mock.MagicMock()
entity.soco.group.members = [e.soco for e in entities]
sonos.SonosEntity.snapshot_multi(entities, True)
run_coroutine_threadsafe(
sonos.SonosEntity.snapshot_multi(self.hass, entities, True),
self.hass.loop).result()
assert snapshotMock.call_count == 1
assert snapshotMock.call_args == mock.call()
@ -350,6 +353,8 @@ class TestSonosMediaPlayer(unittest.TestCase):
entity._snapshot_group = mock.MagicMock()
entity._snapshot_group.members = [e.soco for e in entities]
entity._soco_snapshot = Snapshot(entity.soco)
sonos.SonosEntity.restore_multi(entities, True)
run_coroutine_threadsafe(
sonos.SonosEntity.restore_multi(self.hass, entities, True),
self.hass.loop).result()
assert restoreMock.call_count == 1
assert restoreMock.call_args == mock.call()