From e136847b89c086076c3e36dfa8526255da72366e Mon Sep 17 00:00:00 2001 From: Cyrill Raccaud Date: Tue, 26 Mar 2024 10:17:25 +0100 Subject: [PATCH] Add more timestamp sensors to swiss_public_transport (#107916) * add more timestamp sensors * more generic definition for future sensors * add entity descriptor * use enable property to prevent sensors from getting added * set legacy attribute flag for first sensor * remove departure from extra attributes * remove breaking changes again and keep for next pr * fix multiline statements * outsource the multiline ifs into function --------- Co-authored-by: Joost Lekkerkerker --- .../swiss_public_transport/const.py | 2 + .../swiss_public_transport/coordinator.py | 82 +++++++++---------- .../swiss_public_transport/sensor.py | 80 ++++++++++++++---- .../swiss_public_transport/strings.json | 8 +- 4 files changed, 115 insertions(+), 57 deletions(-) diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index 6d9fb8bb960..6ae3cc9fd2f 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -7,6 +7,8 @@ CONF_START = "from" DEFAULT_NAME = "Next Destination" +SENSOR_CONNECTIONS_COUNT = 3 + PLACEHOLDERS = { "stationboard_url": "http://transport.opendata.ch/examples/stationboard.html", diff --git a/homeassistant/components/swiss_public_transport/coordinator.py b/homeassistant/components/swiss_public_transport/coordinator.py index d24dc85e3dc..7df593d5667 100644 --- a/homeassistant/components/swiss_public_transport/coordinator.py +++ b/homeassistant/components/swiss_public_transport/coordinator.py @@ -14,7 +14,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed import homeassistant.util.dt as dt_util -from .const import DOMAIN +from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT _LOGGER = logging.getLogger(__name__) @@ -23,8 +23,8 @@ class DataConnection(TypedDict): """A connection data class.""" departure: datetime | None - next_departure: str | None - next_on_departure: str | None + next_departure: datetime | None + next_on_departure: datetime | None duration: str platform: str remaining_time: str @@ -35,7 +35,9 @@ class DataConnection(TypedDict): delay: int -class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnection]): +class SwissPublicTransportDataUpdateCoordinator( + DataUpdateCoordinator[list[DataConnection]] +): """A SwissPublicTransport Data Update Coordinator.""" config_entry: ConfigEntry @@ -50,7 +52,22 @@ class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnec ) self._opendata = opendata - async def _async_update_data(self) -> DataConnection: + def remaining_time(self, departure) -> timedelta | None: + """Calculate the remaining time for the departure.""" + departure_datetime = dt_util.parse_datetime(departure) + + if departure_datetime: + return departure_datetime - dt_util.as_local(dt_util.utcnow()) + return None + + def nth_departure_time(self, i: int) -> datetime | None: + """Get nth departure time.""" + connections = self._opendata.connections + if len(connections) > i and connections[i] is not None: + return dt_util.parse_datetime(connections[i]["departure"]) + return None + + async def _async_update_data(self) -> list[DataConnection]: try: await self._opendata.async_get_data() except OpendataTransportError as e: @@ -59,41 +76,22 @@ class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnec ) raise UpdateFailed from e - departure_time = ( - dt_util.parse_datetime(self._opendata.connections[0]["departure"]) - if self._opendata.connections[0] is not None - else None - ) - next_departure_time = ( - dt_util.parse_datetime(self._opendata.connections[1]["departure"]) - if self._opendata.connections[1] is not None - else None - ) - next_on_departure_time = ( - dt_util.parse_datetime(self._opendata.connections[2]["departure"]) - if self._opendata.connections[2] is not None - else None - ) + connections = self._opendata.connections - if departure_time: - remaining_time = departure_time - dt_util.as_local(dt_util.utcnow()) - else: - remaining_time = None - - return DataConnection( - departure=departure_time, - next_departure=next_departure_time.isoformat() - if next_departure_time is not None - else None, - next_on_departure=next_on_departure_time.isoformat() - if next_on_departure_time is not None - else None, - train_number=self._opendata.connections[0]["number"], - platform=self._opendata.connections[0]["platform"], - transfers=self._opendata.connections[0]["transfers"], - duration=self._opendata.connections[0]["duration"], - start=self._opendata.from_name, - destination=self._opendata.to_name, - remaining_time=f"{remaining_time}", - delay=self._opendata.connections[0]["delay"], - ) + return [ + DataConnection( + departure=self.nth_departure_time(i), + next_departure=self.nth_departure_time(i + 1), + next_on_departure=self.nth_departure_time(i + 2), + train_number=connections[i]["number"], + platform=connections[i]["platform"], + transfers=connections[i]["transfers"], + duration=connections[i]["duration"], + start=self._opendata.from_name, + destination=self._opendata.to_name, + remaining_time=str(self.remaining_time(connections[i]["departure"])), + delay=connections[i]["delay"], + ) + for i in range(SENSOR_CONNECTIONS_COUNT) + if len(connections) > i and connections[i] is not None + ] diff --git a/homeassistant/components/swiss_public_transport/sensor.py b/homeassistant/components/swiss_public_transport/sensor.py index 4bca9aade60..7c712c8c189 100644 --- a/homeassistant/components/swiss_public_transport/sensor.py +++ b/homeassistant/components/swiss_public_transport/sensor.py @@ -2,6 +2,8 @@ from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass from datetime import datetime, timedelta import logging from typing import TYPE_CHECKING @@ -13,6 +15,7 @@ from homeassistant.components.sensor import ( PLATFORM_SCHEMA, SensorDeviceClass, SensorEntity, + SensorEntityDescription, ) from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import CONF_NAME @@ -25,8 +28,15 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.update_coordinator import CoordinatorEntity -from .const import CONF_DESTINATION, CONF_START, DEFAULT_NAME, DOMAIN, PLACEHOLDERS -from .coordinator import SwissPublicTransportDataUpdateCoordinator +from .const import ( + CONF_DESTINATION, + CONF_START, + DEFAULT_NAME, + DOMAIN, + PLACEHOLDERS, + SENSOR_CONNECTIONS_COUNT, +) +from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) @@ -41,6 +51,33 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( ) +@dataclass(kw_only=True, frozen=True) +class SwissPublicTransportSensorEntityDescription(SensorEntityDescription): + """Describes swiss public transport sensor entity.""" + + exists_fn: Callable[[DataConnection], bool] + value_fn: Callable[[DataConnection], datetime | None] + + index: int + has_legacy_attributes: bool + + +SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = ( + *[ + SwissPublicTransportSensorEntityDescription( + key=f"departure{i or ''}", + translation_key=f"departure{i}", + device_class=SensorDeviceClass.TIMESTAMP, + has_legacy_attributes=i == 0, + value_fn=lambda data_connection: data_connection["departure"], + exists_fn=lambda data_connection: data_connection is not None, + index=i, + ) + for i in range(SENSOR_CONNECTIONS_COUNT) + ], +) + + async def async_setup_entry( hass: core.HomeAssistant, config_entry: config_entries.ConfigEntry, @@ -55,7 +92,8 @@ async def async_setup_entry( assert unique_id async_add_entities( - [SwissPublicTransportSensor(coordinator, unique_id)], + SwissPublicTransportSensor(coordinator, description, unique_id) + for description in SENSORS ) @@ -108,34 +146,51 @@ class SwissPublicTransportSensor( ): """Implementation of a Swiss public transport sensor.""" + entity_description: SwissPublicTransportSensorEntityDescription _attr_attribution = "Data provided by transport.opendata.ch" _attr_has_entity_name = True - _attr_translation_key = "departure" - _attr_device_class = SensorDeviceClass.TIMESTAMP def __init__( self, coordinator: SwissPublicTransportDataUpdateCoordinator, + entity_description: SwissPublicTransportSensorEntityDescription, unique_id: str, ) -> None: """Initialize the sensor.""" super().__init__(coordinator) - self._attr_unique_id = f"{unique_id}_departure" + self.entity_description = entity_description + self._attr_unique_id = f"{unique_id}_{entity_description.key}" self._attr_device_info = DeviceInfo( identifiers={(DOMAIN, unique_id)}, manufacturer="Opendata.ch", entry_type=DeviceEntryType.SERVICE, ) + @property + def enabled(self) -> bool: + """Enable the sensor if data is available.""" + return self.entity_description.exists_fn( + self.coordinator.data[self.entity_description.index] + ) + + @property + def native_value(self) -> datetime | None: + """Return the state of the sensor.""" + return self.entity_description.value_fn( + self.coordinator.data[self.entity_description.index] + ) + async def async_added_to_hass(self) -> None: """Prepare the extra attributes at start.""" - self._async_update_attrs() + if self.entity_description.has_legacy_attributes: + self._async_update_attrs() await super().async_added_to_hass() @callback def _handle_coordinator_update(self) -> None: """Handle the state update and prepare the extra state attributes.""" - self._async_update_attrs() + if self.entity_description.has_legacy_attributes: + self._async_update_attrs() return super()._handle_coordinator_update() @callback @@ -143,11 +198,8 @@ class SwissPublicTransportSensor( """Update the extra state attributes based on the coordinator data.""" self._attr_extra_state_attributes = { key: value - for key, value in self.coordinator.data.items() + for key, value in self.coordinator.data[ + self.entity_description.index + ].items() if key not in {"departure"} } - - @property - def native_value(self) -> datetime | None: - """Return the state of the sensor.""" - return self.coordinator.data["departure"] diff --git a/homeassistant/components/swiss_public_transport/strings.json b/homeassistant/components/swiss_public_transport/strings.json index 6d0eb53ad11..c0e88f08b8d 100644 --- a/homeassistant/components/swiss_public_transport/strings.json +++ b/homeassistant/components/swiss_public_transport/strings.json @@ -24,8 +24,14 @@ }, "entity": { "sensor": { - "departure": { + "departure0": { "name": "Departure" + }, + "departure1": { + "name": "Departure +1" + }, + "departure2": { + "name": "Departure +2" } } },