Avoid blocking I/O in gpsd (#122176)

This commit is contained in:
Jan Rieger 2024-07-19 18:25:07 +02:00 committed by GitHub
parent 72d37036b9
commit 12ec66c2c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 54 additions and 22 deletions

View file

@ -2,19 +2,45 @@
from __future__ import annotations
from gps3.agps3threaded import AGPS3mechanism
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.const import CONF_HOST, CONF_PORT, Platform
from homeassistant.core import HomeAssistant
PLATFORMS: list[Platform] = [Platform.SENSOR]
type GPSDConfigEntry = ConfigEntry[AGPS3mechanism]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: GPSDConfigEntry) -> bool:
"""Set up GPSD from a config entry."""
agps_thread = AGPS3mechanism()
entry.runtime_data = agps_thread
def setup_agps() -> None:
host = entry.data.get(CONF_HOST)
port = entry.data.get(CONF_PORT)
agps_thread.stream_data(host, port)
agps_thread.run_thread()
await hass.async_add_executor_job(setup_agps)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: GPSDConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
agps_thread = entry.runtime_data
await hass.async_add_executor_job(
lambda: agps_thread.stream_data(
host=entry.data.get(CONF_HOST),
port=entry.data.get(CONF_PORT),
enable=False,
)
)
return unload_ok

View file

@ -27,6 +27,18 @@ class GPSDConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
@staticmethod
def test_connection(host: str, port: int) -> bool:
"""Test socket connection."""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, port))
sock.shutdown(2)
except OSError:
return False
else:
return True
async def async_step_import(self, import_data: dict[str, Any]) -> ConfigFlowResult:
"""Import a config entry from configuration.yaml."""
return await self.async_step_user(import_data)
@ -38,11 +50,11 @@ class GPSDConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None:
self._async_abort_entries_match(user_input)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect((user_input[CONF_HOST], user_input[CONF_PORT]))
sock.shutdown(2)
except OSError:
connected = await self.hass.async_add_executor_job(
self.test_connection, user_input[CONF_HOST], user_input[CONF_PORT]
)
if not connected:
return self.async_abort(reason="cannot_connect")
port = ""

View file

@ -20,7 +20,7 @@ from homeassistant.components.sensor import (
SensorEntity,
SensorEntityDescription,
)
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import (
ATTR_LATITUDE,
ATTR_LONGITUDE,
@ -37,6 +37,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GPSDConfigEntry
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -81,15 +82,14 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend(
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
config_entry: GPSDConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the GPSD component."""
async_add_entities(
[
GpsdSensor(
config_entry.data[CONF_HOST],
config_entry.data[CONF_PORT],
config_entry.runtime_data,
config_entry.entry_id,
description,
)
@ -135,8 +135,7 @@ class GpsdSensor(SensorEntity):
def __init__(
self,
host: str,
port: int,
agps_thread: AGPS3mechanism,
unique_id: str,
description: GpsdSensorDescription,
) -> None:
@ -148,9 +147,7 @@ class GpsdSensor(SensorEntity):
)
self._attr_unique_id = f"{unique_id}-{self.entity_description.key}"
self.agps_thread = AGPS3mechanism()
self.agps_thread.stream_data(host=host, port=port)
self.agps_thread.run_thread()
self.agps_thread = agps_thread
@property
def native_value(self) -> str | None:

View file

@ -43,10 +43,7 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
async def test_connection_error(hass: HomeAssistant) -> None:
"""Test connection to host error."""
with patch("socket.socket") as mock_socket:
mock_connect = mock_socket.return_value.connect
mock_connect.side_effect = OSError
with patch("socket.socket", side_effect=OSError):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_USER},