Remove legacy reproduce state (#28458)
* Remove legacy reproduce state * Fix imports
This commit is contained in:
parent
9b72a55d60
commit
552fbda58b
3 changed files with 12 additions and 143 deletions
|
@ -14,7 +14,7 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
SERVICE_RELOAD,
|
||||
)
|
||||
from homeassistant.core import State, DOMAIN
|
||||
from homeassistant.core import State, DOMAIN as HA_DOMAIN
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import async_get_integration
|
||||
|
@ -23,7 +23,7 @@ from homeassistant.helpers import (
|
|||
config_validation as cv,
|
||||
entity_platform,
|
||||
)
|
||||
from homeassistant.helpers.state import HASS_DOMAIN, async_reproduce_state
|
||||
from homeassistant.helpers.state import async_reproduce_state
|
||||
from homeassistant.components.scene import DOMAIN as SCENE_DOMAIN, STATES, Scene
|
||||
|
||||
|
||||
|
@ -60,7 +60,7 @@ STATES_SCHEMA = vol.All(dict, _convert_states)
|
|||
|
||||
PLATFORM_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): HASS_DOMAIN,
|
||||
vol.Required(CONF_PLATFORM): HA_DOMAIN,
|
||||
vol.Required(STATES): vol.All(
|
||||
cv.ensure_list,
|
||||
[
|
||||
|
@ -114,7 +114,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
|||
|
||||
# Extract only the config for the Home Assistant platform, ignore the rest.
|
||||
for p_type, p_config in config_per_platform(conf, SCENE_DOMAIN):
|
||||
if p_type != DOMAIN:
|
||||
if p_type != HA_DOMAIN:
|
||||
continue
|
||||
|
||||
_process_scenes_config(hass, async_add_entities, p_config)
|
||||
|
|
|
@ -8,7 +8,6 @@ from homeassistant.core import DOMAIN as HA_DOMAIN
|
|||
from homeassistant.const import CONF_PLATFORM, SERVICE_TURN_ON
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.state import HASS_DOMAIN
|
||||
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
@ -21,7 +20,7 @@ STATES = "states"
|
|||
def _hass_domain_validator(config):
|
||||
"""Validate platform in config for homeassistant domain."""
|
||||
if CONF_PLATFORM not in config:
|
||||
config = {CONF_PLATFORM: HASS_DOMAIN, STATES: config}
|
||||
config = {CONF_PLATFORM: HA_DOMAIN, STATES: config}
|
||||
|
||||
return config
|
||||
|
||||
|
|
|
@ -1,35 +1,15 @@
|
|||
"""Helpers that help with state related things."""
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from types import ModuleType, TracebackType
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
from homeassistant.loader import bind_hass, async_get_integration, IntegrationNotFound
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.components.notify import ATTR_MESSAGE, SERVICE_NOTIFY
|
||||
from homeassistant.components.sun import STATE_ABOVE_HORIZON, STATE_BELOW_HORIZON
|
||||
from homeassistant.components.cover import ATTR_POSITION, ATTR_TILT_POSITION
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
SERVICE_ALARM_ARM_AWAY,
|
||||
SERVICE_ALARM_ARM_HOME,
|
||||
SERVICE_ALARM_DISARM,
|
||||
SERVICE_ALARM_TRIGGER,
|
||||
SERVICE_LOCK,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
SERVICE_UNLOCK,
|
||||
SERVICE_OPEN_COVER,
|
||||
SERVICE_CLOSE_COVER,
|
||||
SERVICE_SET_COVER_POSITION,
|
||||
SERVICE_SET_COVER_TILT_POSITION,
|
||||
STATE_ALARM_ARMED_AWAY,
|
||||
STATE_ALARM_ARMED_HOME,
|
||||
STATE_ALARM_DISARMED,
|
||||
STATE_ALARM_TRIGGERED,
|
||||
STATE_CLOSED,
|
||||
STATE_HOME,
|
||||
STATE_LOCKED,
|
||||
|
@ -40,36 +20,11 @@ from homeassistant.const import (
|
|||
STATE_UNKNOWN,
|
||||
STATE_UNLOCKED,
|
||||
)
|
||||
from homeassistant.core import Context, State, DOMAIN as HASS_DOMAIN
|
||||
from homeassistant.core import Context, State
|
||||
from .typing import HomeAssistantType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
GROUP_DOMAIN = "group"
|
||||
|
||||
# Update this dict of lists when new services are added to HA.
|
||||
# Each item is a service with a list of required attributes.
|
||||
SERVICE_ATTRIBUTES = {
|
||||
SERVICE_NOTIFY: [ATTR_MESSAGE],
|
||||
SERVICE_SET_COVER_POSITION: [ATTR_POSITION],
|
||||
SERVICE_SET_COVER_TILT_POSITION: [ATTR_TILT_POSITION],
|
||||
}
|
||||
|
||||
# Update this dict when new services are added to HA.
|
||||
# Each item is a service with a corresponding state.
|
||||
SERVICE_TO_STATE = {
|
||||
SERVICE_TURN_ON: STATE_ON,
|
||||
SERVICE_TURN_OFF: STATE_OFF,
|
||||
SERVICE_ALARM_ARM_AWAY: STATE_ALARM_ARMED_AWAY,
|
||||
SERVICE_ALARM_ARM_HOME: STATE_ALARM_ARMED_HOME,
|
||||
SERVICE_ALARM_DISARM: STATE_ALARM_DISARMED,
|
||||
SERVICE_ALARM_TRIGGER: STATE_ALARM_TRIGGERED,
|
||||
SERVICE_LOCK: STATE_LOCKED,
|
||||
SERVICE_UNLOCK: STATE_UNLOCKED,
|
||||
SERVICE_OPEN_COVER: STATE_OPEN,
|
||||
SERVICE_CLOSE_COVER: STATE_CLOSED,
|
||||
}
|
||||
|
||||
|
||||
class AsyncTrackStates:
|
||||
"""
|
||||
|
@ -109,18 +64,6 @@ def get_changed_since(
|
|||
return [state for state in states if state.last_updated >= utc_point_in_time]
|
||||
|
||||
|
||||
@bind_hass
|
||||
def reproduce_state(
|
||||
hass: HomeAssistantType,
|
||||
states: Union[State, Iterable[State]],
|
||||
blocking: bool = False,
|
||||
) -> None:
|
||||
"""Reproduce given state."""
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
async_reproduce_state(hass, states, blocking), hass.loop
|
||||
).result()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_reproduce_state(
|
||||
hass: HomeAssistantType,
|
||||
|
@ -149,16 +92,12 @@ async def async_reproduce_state(
|
|||
try:
|
||||
platform: Optional[ModuleType] = integration.get_platform("reproduce_state")
|
||||
except ImportError:
|
||||
platform = None
|
||||
_LOGGER.warning("Integration %s does not support reproduce state", domain)
|
||||
return
|
||||
|
||||
if platform:
|
||||
await platform.async_reproduce_states( # type: ignore
|
||||
hass, states_by_domain, context=context
|
||||
)
|
||||
else:
|
||||
await async_reproduce_state_legacy(
|
||||
hass, domain, states_by_domain, blocking=blocking, context=context
|
||||
)
|
||||
await platform.async_reproduce_states( # type: ignore
|
||||
hass, states_by_domain, context=context
|
||||
)
|
||||
|
||||
if to_call:
|
||||
# run all domains in parallel
|
||||
|
@ -167,75 +106,6 @@ async def async_reproduce_state(
|
|||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_reproduce_state_legacy(
|
||||
hass: HomeAssistantType,
|
||||
domain: str,
|
||||
states: Iterable[State],
|
||||
blocking: bool = False,
|
||||
context: Optional[Context] = None,
|
||||
) -> None:
|
||||
"""Reproduce given state."""
|
||||
to_call: Dict[Tuple[str, str], List[str]] = defaultdict(list)
|
||||
|
||||
if domain == GROUP_DOMAIN:
|
||||
service_domain = HASS_DOMAIN
|
||||
else:
|
||||
service_domain = domain
|
||||
|
||||
for state in states:
|
||||
|
||||
if hass.states.get(state.entity_id) is None:
|
||||
_LOGGER.warning(
|
||||
"reproduce_state: Unable to find entity %s", state.entity_id
|
||||
)
|
||||
continue
|
||||
|
||||
domain_services = hass.services.async_services().get(service_domain)
|
||||
|
||||
if not domain_services:
|
||||
_LOGGER.warning("reproduce_state: Unable to reproduce state %s (1)", state)
|
||||
continue
|
||||
|
||||
service = None
|
||||
for _service in domain_services.keys():
|
||||
if (
|
||||
_service in SERVICE_ATTRIBUTES
|
||||
and all(
|
||||
attr in state.attributes for attr in SERVICE_ATTRIBUTES[_service]
|
||||
)
|
||||
or _service in SERVICE_TO_STATE
|
||||
and SERVICE_TO_STATE[_service] == state.state
|
||||
):
|
||||
service = _service
|
||||
if (
|
||||
_service in SERVICE_TO_STATE
|
||||
and SERVICE_TO_STATE[_service] == state.state
|
||||
):
|
||||
break
|
||||
|
||||
if not service:
|
||||
_LOGGER.warning("reproduce_state: Unable to reproduce state %s (2)", state)
|
||||
continue
|
||||
|
||||
# We group service calls for entities by service call
|
||||
# json used to create a hashable version of dict with maybe lists in it
|
||||
key = (service, json.dumps(dict(state.attributes), sort_keys=True))
|
||||
to_call[key].append(state.entity_id)
|
||||
|
||||
domain_tasks: List[Awaitable[Optional[bool]]] = []
|
||||
for (service, service_data), entity_ids in to_call.items():
|
||||
data = json.loads(service_data)
|
||||
data[ATTR_ENTITY_ID] = entity_ids
|
||||
|
||||
domain_tasks.append(
|
||||
hass.services.async_call(service_domain, service, data, blocking, context)
|
||||
)
|
||||
|
||||
if domain_tasks:
|
||||
await asyncio.wait(domain_tasks)
|
||||
|
||||
|
||||
def state_as_number(state: State) -> float:
|
||||
"""
|
||||
Try to coerce our state to a number.
|
||||
|
|
Loading…
Add table
Reference in a new issue