Limit rainbird aiohttp client session to a single connection (#112146)

Limit rainbird to a single open http connection
This commit is contained in:
Allen Porter 2024-03-03 16:54:05 -08:00 committed by GitHub
parent f9e00ed45b
commit 5cb5a1141f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 47 additions and 6 deletions

View file

@ -11,11 +11,10 @@ from homeassistant.const import CONF_HOST, CONF_MAC, CONF_PASSWORD, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac from homeassistant.helpers.device_registry import format_mac
from .const import CONF_SERIAL_NUMBER from .const import CONF_SERIAL_NUMBER
from .coordinator import RainbirdData from .coordinator import RainbirdData, async_create_clientsession
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -36,9 +35,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {}) hass.data.setdefault(DOMAIN, {})
clientsession = async_create_clientsession()
entry.async_on_unload(clientsession.close)
controller = AsyncRainbirdController( controller = AsyncRainbirdController(
AsyncRainbirdClient( AsyncRainbirdClient(
async_get_clientsession(hass), clientsession,
entry.data[CONF_HOST], entry.data[CONF_HOST],
entry.data[CONF_PASSWORD], entry.data[CONF_PASSWORD],
) )

View file

@ -23,7 +23,6 @@ from homeassistant.config_entries import (
from homeassistant.const import CONF_HOST, CONF_MAC, CONF_PASSWORD from homeassistant.const import CONF_HOST, CONF_MAC, CONF_PASSWORD
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv, selector from homeassistant.helpers import config_validation as cv, selector
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac from homeassistant.helpers.device_registry import format_mac
from .const import ( from .const import (
@ -33,6 +32,7 @@ from .const import (
DOMAIN, DOMAIN,
TIMEOUT_SECONDS, TIMEOUT_SECONDS,
) )
from .coordinator import async_create_clientsession
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -104,9 +104,10 @@ class RainbirdConfigFlowHandler(ConfigFlow, domain=DOMAIN):
Raises a ConfigFlowError on failure. Raises a ConfigFlowError on failure.
""" """
clientsession = async_create_clientsession()
controller = AsyncRainbirdController( controller = AsyncRainbirdController(
AsyncRainbirdClient( AsyncRainbirdClient(
async_get_clientsession(self.hass), clientsession,
host, host,
password, password,
) )
@ -127,6 +128,8 @@ class RainbirdConfigFlowHandler(ConfigFlow, domain=DOMAIN):
f"Error connecting to Rain Bird controller: {str(err)}", f"Error connecting to Rain Bird controller: {str(err)}",
"cannot_connect", "cannot_connect",
) from err ) from err
finally:
await clientsession.close()
async def async_finish( async def async_finish(
self, self,

View file

@ -9,6 +9,7 @@ from functools import cached_property
import logging import logging
from typing import TypeVar from typing import TypeVar
import aiohttp
from pyrainbird.async_client import ( from pyrainbird.async_client import (
AsyncRainbirdController, AsyncRainbirdController,
RainbirdApiException, RainbirdApiException,
@ -28,6 +29,9 @@ UPDATE_INTERVAL = datetime.timedelta(minutes=1)
# changes, so we refresh it less often. # changes, so we refresh it less often.
CALENDAR_UPDATE_INTERVAL = datetime.timedelta(minutes=15) CALENDAR_UPDATE_INTERVAL = datetime.timedelta(minutes=15)
# Rainbird devices can only accept a single request at a time
CONECTION_LIMIT = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_T = TypeVar("_T") _T = TypeVar("_T")
@ -43,6 +47,13 @@ class RainbirdDeviceState:
rain_delay: int rain_delay: int
def async_create_clientsession() -> aiohttp.ClientSession:
"""Create a rainbird async_create_clientsession with a connection limit."""
return aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=CONECTION_LIMIT),
)
class RainbirdUpdateCoordinator(DataUpdateCoordinator[RainbirdDeviceState]): class RainbirdUpdateCoordinator(DataUpdateCoordinator[RainbirdDeviceState]):
"""Coordinator for rainbird API calls.""" """Coordinator for rainbird API calls."""

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus from http import HTTPStatus
import json import json
from typing import Any from typing import Any
@ -15,7 +16,7 @@ from homeassistant.components.rainbird.const import (
ATTR_DURATION, ATTR_DURATION,
DEFAULT_TRIGGER_TIME_MINUTES, DEFAULT_TRIGGER_TIME_MINUTES,
) )
from homeassistant.const import Platform from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -155,6 +156,31 @@ def setup_platforms(
yield yield
@pytest.fixture(autouse=True)
def aioclient_mock(hass: HomeAssistant) -> Generator[AiohttpClientMocker, None, None]:
"""Context manager to mock aiohttp client."""
mocker = AiohttpClientMocker()
def create_session():
session = mocker.create_session(hass.loop)
async def close_session(event):
"""Close session."""
await session.close()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, close_session)
return session
with patch(
"homeassistant.components.rainbird.async_create_clientsession",
side_effect=create_session,
), patch(
"homeassistant.components.rainbird.config_flow.async_create_clientsession",
side_effect=create_session,
):
yield mocker
def rainbird_json_response(result: dict[str, str]) -> bytes: def rainbird_json_response(result: dict[str, str]) -> bytes:
"""Create a fake API response.""" """Create a fake API response."""
return encryption.encrypt( return encryption.encrypt(