Make reproduce state use platform instead of rely on function (#25856)

* Make reproduce state use platform instead of rely on function

* Fix types

* address comment Martin.
This commit is contained in:
Paulus Schoutsen 2019-08-11 20:03:21 -07:00 committed by GitHub
parent ab7db5fbd0
commit cf90e49b50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 26 additions and 21 deletions

View file

@ -4,7 +4,7 @@ import datetime as dt
import json
import logging
from collections import defaultdict
from types import TracebackType
from types import ModuleType, TracebackType
from typing import ( # noqa: F401 pylint: disable=unused-import
Awaitable,
Dict,
@ -16,7 +16,7 @@ from typing import ( # noqa: F401 pylint: disable=unused-import
Union,
)
from homeassistant.loader import bind_hass
from homeassistant.loader import bind_hass, async_get_integration, IntegrationNotFound
import homeassistant.util.dt as dt_util
from homeassistant.components.notify import ATTR_MESSAGE, SERVICE_NOTIFY
from homeassistant.components.sun import STATE_ABOVE_HORIZON, STATE_BELOW_HORIZON
@ -152,13 +152,27 @@ async def async_reproduce_state(
for state in states:
to_call[state.domain].append(state)
async def worker(domain: str, data: List[State]) -> None:
component = getattr(hass.components, domain)
if hasattr(component, "async_reproduce_states"):
await component.async_reproduce_states(data, context=context)
async def worker(domain: str, states_by_domain: List[State]) -> None:
try:
integration = await async_get_integration(hass, domain)
except IntegrationNotFound:
_LOGGER.warning(
"Trying to reproduce state for unknown integration: %s", domain
)
return
try:
platform: Optional[ModuleType] = integration.get_platform("reproduce_state")
except ImportError:
platform = None
if platform:
await platform.async_reproduce_states( # type: ignore
hass, states_by_domain, context=context
)
else:
await async_reproduce_state_legacy(
hass, domain, data, blocking=blocking, context=context
hass, domain, states_by_domain, blocking=blocking, context=context
)
if to_call: