diff --git a/homeassistant/helpers/singleton.py b/homeassistant/helpers/singleton.py index 82b666be40a..ab4d12dc1cc 100644 --- a/homeassistant/helpers/singleton.py +++ b/homeassistant/helpers/singleton.py @@ -1,14 +1,14 @@ """Helper to help coordinating calls.""" import asyncio import functools -from typing import Awaitable, Callable, TypeVar, cast +from typing import Callable, Optional, TypeVar, cast from homeassistant.core import HomeAssistant from homeassistant.loader import bind_hass T = TypeVar("T") -FUNC = Callable[[HomeAssistant], Awaitable[T]] +FUNC = Callable[[HomeAssistant], T] def singleton(data_key: str) -> Callable[[FUNC], FUNC]: @@ -19,10 +19,21 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]: def wrapper(func: FUNC) -> FUNC: """Wrap a function with caching logic.""" + if not asyncio.iscoroutinefunction(func): + + @bind_hass + @functools.wraps(func) + def wrapped(hass: HomeAssistant) -> T: + obj: Optional[T] = hass.data.get(data_key) + if obj is None: + obj = hass.data[data_key] = func(hass) + return obj + + return wrapped @bind_hass @functools.wraps(func) - async def wrapped(hass: HomeAssistant) -> T: + async def async_wrapped(hass: HomeAssistant) -> T: obj_or_evt = hass.data.get(data_key) if not obj_or_evt: @@ -41,6 +52,6 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]: return cast(T, obj_or_evt) - return wrapped + return async_wrapped return wrapper diff --git a/tests/helpers/test_singleton.py b/tests/helpers/test_singleton.py new file mode 100644 index 00000000000..03230a02ab8 --- /dev/null +++ b/tests/helpers/test_singleton.py @@ -0,0 +1,40 @@ +"""Test singleton helper.""" +import pytest + +from homeassistant.helpers import singleton + +from tests.async_mock import Mock + + +@pytest.fixture +def mock_hass(): + """Mock hass fixture.""" + return Mock(data={}) + + +async def test_singleton_async(mock_hass): + """Test singleton with async function.""" + + @singleton.singleton("test_key") + async def something(hass): + return object() + + result1 = await something(mock_hass) + result2 = await something(mock_hass) + assert result1 is result2 + assert "test_key" in mock_hass.data + assert mock_hass.data["test_key"] is result1 + + +def test_singleton(mock_hass): + """Test singleton with function.""" + + @singleton.singleton("test_key") + def something(hass): + return object() + + result1 = something(mock_hass) + result2 = something(mock_hass) + assert result1 is result2 + assert "test_key" in mock_hass.data + assert mock_hass.data["test_key"] is result1